Example #1
    def __init__(self, env, modalities):
        super(VisionSensor, self).__init__(env)
        self.modalities = modalities
        self.raw_modalities = self.get_raw_modalities(modalities)
        self.image_width = self.config.get('image_width', 128)
        self.image_height = self.config.get('image_height', 128)

        self.depth_noise_rate = self.config.get('depth_noise_rate', 0.0)
        self.depth_low = self.config.get('depth_low', 0.5)
        self.depth_high = self.config.get('depth_high', 5.0)

        self.noise_model = DropoutSensorNoise(env)
        self.noise_model.set_noise_rate(self.depth_noise_rate)
        self.noise_model.set_noise_value(0.0)

        if 'rgb_filled' in modalities:
            try:
                import torch.nn as nn
                import torch
                from torchvision import transforms
                from gibson2.learn.completion import CompletionNet
            except ImportError:
                raise Exception(
                    'Trying to use rgb_filled ("the goggle"), but torch is not installed. Try "pip install torch torchvision".'
                )

            self.comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
            self.comp = torch.nn.DataParallel(self.comp).cuda()
            self.comp.load_state_dict(
                torch.load(
                    os.path.join(gibson2.assets_path, 'networks',
                                 'model.pth')))
            self.comp.eval()
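
A minimal sketch of how a sensor configured like this might produce its readings, loosely following the get_depth/get_state patterns in the later examples; the add_noise call on DropoutSensorNoise and the raw_rgb/raw_depth inputs are assumptions, not part of the snippet above.

import numpy as np
import torch
from torchvision import transforms

def get_obs_sketch(sensor, raw_rgb, raw_depth):
    # raw_rgb: HxWx3 float image in [0, 1]; raw_depth: HxWx1 metric depth (assumed inputs)
    obs = {}
    if 'rgb' in sensor.modalities:
        obs['rgb'] = raw_rgb
    if 'depth' in sensor.modalities:
        depth = raw_depth.copy()
        depth[depth < sensor.depth_low] = 0.0    # 0.0 marks invalid readings
        depth[depth > sensor.depth_high] = 0.0
        depth /= sensor.depth_high               # rescale to [0, 1]
        obs['depth'] = sensor.noise_model.add_noise(depth)  # assumed DropoutSensorNoise API
    if 'rgb_filled' in sensor.modalities:
        with torch.no_grad():
            tensor = transforms.ToTensor()((raw_rgb * 255).astype(np.uint8)).cuda()
            obs['rgb_filled'] = sensor.comp(tensor[None])[0].permute(1, 2, 0).cpu().numpy()
    return obs
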
Example #2
    def load_model(self):
        comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
        comp = nn.DataParallel(comp).cuda()
        comp.load_state_dict(
            torch.load(os.path.join(assets_file_dir, "unfiller_256.pth")))

        model = comp.module
        model.eval()
        print(model)
        return model
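
Because comp is wrapped in nn.DataParallel before load_state_dict, the checkpoint keys presumably carry a 'module.' prefix. A hedged sketch of loading the same weights without a GPU by stripping that prefix; the import path is taken from the other examples and assets_file_dir is the same (externally defined) variable as above.

import os
import torch
import torch.nn as nn
from gibson2.learn.completion import CompletionNet

def load_model_cpu():
    # CPU-only variant (sketch): strip the 'module.' prefix added by DataParallel
    comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
    state_dict = torch.load(os.path.join(assets_file_dir, "unfiller_256.pth"),
                            map_location='cpu')
    state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    comp.load_state_dict(state_dict)
    comp.eval()
    return comp
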
Example #3
    def load(self):
        super(NavigateEnv, self).load()
        self.initial_pos = np.array(self.config.get('initial_pos', [0, 0, 0]))
        self.initial_orn = np.array(self.config.get('initial_orn', [0, 0, 0]))

        self.target_pos = np.array(self.config.get('target_pos', [5, 5, 0]))
        self.target_orn = np.array(self.config.get('target_orn', [0, 0, 0]))

        self.additional_states_dim = self.config['additional_states_dim']

        # termination condition
        self.dist_tol = self.config.get('dist_tol', 0.2)
        self.max_step = self.config.get('max_step', float('inf'))

        # reward
        self.success_reward = self.config.get('success_reward', 10.0)
        self.slack_reward = self.config.get('slack_reward', -0.01)

        # reward weight
        self.potential_reward_weight = self.config.get(
            'potential_reward_weight', 10.0)
        self.electricity_reward_weight = self.config.get(
            'electricity_reward_weight', 0.0)
        self.stall_torque_reward_weight = self.config.get(
            'stall_torque_reward_weight', 0.0)
        self.collision_reward_weight = self.config.get(
            'collision_reward_weight', 0.0)

        # discount factor
        self.discount_factor = self.config.get('discount_factor', 1.0)
        self.output = self.config['output']

        self.sensor_dim = self.robots[0].sensor_dim + self.additional_states_dim
        self.action_dim = self.robots[0].action_dim

        observation_space = OrderedDict()
        if 'sensor' in self.output:
            self.sensor_space = gym.spaces.Box(low=-np.inf,
                                               high=np.inf,
                                               shape=(self.sensor_dim, ),
                                               dtype=np.float32)
            observation_space['sensor'] = self.sensor_space
        if 'pointgoal' in self.output:
            self.pointgoal_space = gym.spaces.Box(low=-np.inf,
                                                  high=np.inf,
                                                  shape=(2, ),
                                                  dtype=np.float32)
            observation_space['pointgoal'] = self.pointgoal_space
        if 'rgb' in self.output:
            self.rgb_space = gym.spaces.Box(low=0.0,
                                            high=1.0,
                                            shape=(self.config['resolution'],
                                                   self.config['resolution'],
                                                   3),
                                            dtype=np.float32)
            observation_space['rgb'] = self.rgb_space
        if 'depth' in self.output:
            self.depth_space = gym.spaces.Box(low=0.0,
                                              high=1.0,
                                              shape=(self.config['resolution'],
                                                     self.config['resolution'],
                                                     1),
                                              dtype=np.float32)
            observation_space['depth'] = self.depth_space
        if 'rgb_filled' in self.output:  # use filler
            self.comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
            self.comp = torch.nn.DataParallel(self.comp).cuda()
            self.comp.load_state_dict(
                torch.load(
                    os.path.join(gibson2.assets_path, 'networks',
                                 'model.pth')))
            self.comp.eval()

        self.observation_space = gym.spaces.Dict(observation_space)
        self.action_space = self.robots[0].action_space

        # variable initialization
        self.current_episode = 0

        # add visual objects
        self.visual_object_at_initial_target_pos = self.config.get(
            'visual_object_at_initial_target_pos', False)

        if self.visual_object_at_initial_target_pos:
            self.initial_pos_vis_obj = VisualObject(rgba_color=[1, 0, 0, 0.5])
            self.target_pos_vis_obj = VisualObject(rgba_color=[0, 0, 1, 0.5])
            self.initial_pos_vis_obj.load()
            if self.config.get('target_visual_object_visible_to_agent', False):
                self.simulator.import_object(self.target_pos_vis_obj)
            else:
                self.target_pos_vis_obj.load()
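
The values above are read from the environment config (self.config). A hypothetical config sketch, written as a Python dict for illustration, using the same defaults the code falls back to; additional_states_dim, output and resolution have no defaults in this method and must be provided.

# Hypothetical config (normally loaded from the YAML file passed as config_file).
config = {
    'initial_pos': [0, 0, 0],
    'initial_orn': [0, 0, 0],
    'target_pos': [5, 5, 0],
    'target_orn': [0, 0, 0],
    'additional_states_dim': 4,            # required: no default in load() above
    'dist_tol': 0.2,
    'success_reward': 10.0,
    'slack_reward': -0.01,
    'potential_reward_weight': 10.0,
    'electricity_reward_weight': 0.0,
    'stall_torque_reward_weight': 0.0,
    'collision_reward_weight': 0.0,
    'discount_factor': 1.0,
    'output': ['sensor', 'rgb', 'depth'],  # required: selects observation modalities
    'resolution': 128,                     # required when 'rgb' or 'depth' is requested
    'visual_object_at_initial_target_pos': False,
}
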
Example #4
class NavigateEnv(BaseEnv):
    def __init__(
        self,
        config_file,
        mode='headless',
        action_timestep=1 / 10.0,
        physics_timestep=1 / 240.0,
        automatic_reset=False,
        device_idx=0,
    ):
        super(NavigateEnv, self).__init__(config_file=config_file,
                                          mode=mode,
                                          device_idx=device_idx)
        self.automatic_reset = automatic_reset

        # simulation
        self.mode = mode
        self.action_timestep = action_timestep
        self.physics_timestep = physics_timestep
        self.simulator.set_timestep(physics_timestep)
        self.simulator_loop = int(self.action_timestep /
                                  self.simulator.timestep)

    def load(self):
        super(NavigateEnv, self).load()
        self.initial_pos = np.array(self.config.get('initial_pos', [0, 0, 0]))
        self.initial_orn = np.array(self.config.get('initial_orn', [0, 0, 0]))

        self.target_pos = np.array(self.config.get('target_pos', [5, 5, 0]))
        self.target_orn = np.array(self.config.get('target_orn', [0, 0, 0]))

        self.additional_states_dim = self.config['additional_states_dim']

        # termination condition
        self.dist_tol = self.config.get('dist_tol', 0.2)
        self.max_step = self.config.get('max_step', float('inf'))

        # reward
        self.success_reward = self.config.get('success_reward', 10.0)
        self.slack_reward = self.config.get('slack_reward', -0.01)

        # reward weight
        self.potential_reward_weight = self.config.get(
            'potential_reward_weight', 10.0)
        self.electricity_reward_weight = self.config.get(
            'electricity_reward_weight', 0.0)
        self.stall_torque_reward_weight = self.config.get(
            'stall_torque_reward_weight', 0.0)
        self.collision_reward_weight = self.config.get(
            'collision_reward_weight', 0.0)

        # discount factor
        self.discount_factor = self.config.get('discount_factor', 1.0)
        self.output = self.config['output']

        self.sensor_dim = self.robots[0].sensor_dim + self.additional_states_dim
        self.action_dim = self.robots[0].action_dim

        observation_space = OrderedDict()
        if 'sensor' in self.output:
            self.sensor_space = gym.spaces.Box(low=-np.inf,
                                               high=np.inf,
                                               shape=(self.sensor_dim, ),
                                               dtype=np.float32)
            observation_space['sensor'] = self.sensor_space
        if 'pointgoal' in self.output:
            self.pointgoal_space = gym.spaces.Box(low=-np.inf,
                                                  high=np.inf,
                                                  shape=(2, ),
                                                  dtype=np.float32)
            observation_space['pointgoal'] = self.pointgoal_space
        if 'rgb' in self.output:
            self.rgb_space = gym.spaces.Box(low=0.0,
                                            high=1.0,
                                            shape=(self.config['resolution'],
                                                   self.config['resolution'],
                                                   3),
                                            dtype=np.float32)
            observation_space['rgb'] = self.rgb_space
        if 'depth' in self.output:
            self.depth_space = gym.spaces.Box(low=0.0,
                                              high=1.0,
                                              shape=(self.config['resolution'],
                                                     self.config['resolution'],
                                                     1),
                                              dtype=np.float32)
            observation_space['depth'] = self.depth_space
        if 'rgb_filled' in self.output:  # use filler
            self.comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
            self.comp = torch.nn.DataParallel(self.comp).cuda()
            self.comp.load_state_dict(
                torch.load(
                    os.path.join(gibson2.assets_path, 'networks',
                                 'model.pth')))
            self.comp.eval()

        self.observation_space = gym.spaces.Dict(observation_space)
        self.action_space = self.robots[0].action_space

        # variable initialization
        self.current_episode = 0

        # add visual objects
        self.visual_object_at_initial_target_pos = self.config.get(
            'visual_object_at_initial_target_pos', False)

        if self.visual_object_at_initial_target_pos:
            self.initial_pos_vis_obj = VisualObject(rgba_color=[1, 0, 0, 0.5])
            self.target_pos_vis_obj = VisualObject(rgba_color=[0, 0, 1, 0.5])
            self.initial_pos_vis_obj.load()
            if self.config.get('target_visual_object_visible_to_agent', False):
                self.simulator.import_object(self.target_pos_vis_obj)
            else:
                self.target_pos_vis_obj.load()

    def reload(self, config_file):
        super(NavigateEnv, self).reload(config_file)
        self.initial_pos = np.array(self.config.get('initial_pos', [0, 0, 0]))
        self.initial_orn = np.array(self.config.get('initial_orn', [0, 0, 0]))

        self.target_pos = np.array(self.config.get('target_pos', [5, 5, 0]))
        self.target_orn = np.array(self.config.get('target_orn', [0, 0, 0]))

        self.additional_states_dim = self.config['additional_states_dim']

        # termination condition
        self.dist_tol = self.config.get('dist_tol', 0.5)
        self.max_step = self.config.get('max_step', float('inf'))

        # reward
        self.terminal_reward = self.config.get('terminal_reward', 0.0)
        self.electricity_cost = self.config.get('electricity_cost', 0.0)
        self.stall_torque_cost = self.config.get('stall_torque_cost', 0.0)
        self.collision_cost = self.config.get('collision_cost', 0.0)
        self.discount_factor = self.config.get('discount_factor', 1.0)
        self.output = self.config['output']

        self.sensor_dim = self.additional_states_dim
        self.action_dim = self.robots[0].action_dim

        # self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.sensor_dim,), dtype=np.float64)
        observation_space = OrderedDict()
        if 'sensor' in self.output:
            self.sensor_space = gym.spaces.Box(low=-np.inf,
                                               high=np.inf,
                                               shape=(self.sensor_dim, ),
                                               dtype=np.float32)
            observation_space['sensor'] = self.sensor_space
        if 'rgb' in self.output:
            self.rgb_space = gym.spaces.Box(low=0.0,
                                            high=1.0,
                                            shape=(self.config['resolution'],
                                                   self.config['resolution'],
                                                   3),
                                            dtype=np.float32)
            observation_space['rgb'] = self.rgb_space
        if 'depth' in self.output:
            self.depth_space = gym.spaces.Box(low=0.0,
                                              high=1.0,
                                              shape=(self.config['resolution'],
                                                     self.config['resolution'],
                                                     1),
                                              dtype=np.float32)
            observation_space['depth'] = self.depth_space
        if 'rgb_filled' in self.output:  # use filler
            self.comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
            self.comp = torch.nn.DataParallel(self.comp).cuda()
            self.comp.load_state_dict(
                torch.load(
                    os.path.join(gibson2.assets_path, 'networks',
                                 'model.pth')))
            self.comp.eval()
        if 'pointgoal' in self.output:
            observation_space['pointgoal'] = gym.spaces.Box(low=-np.inf,
                                                            high=np.inf,
                                                            shape=(2, ),
                                                            dtype=np.float32)

        self.observation_space = gym.spaces.Dict(observation_space)
        self.action_space = self.robots[0].action_space

        self.visual_object_at_initial_target_pos = self.config.get(
            'visual_object_at_initial_target_pos', False)
        if self.visual_object_at_initial_target_pos:
            self.initial_pos_vis_obj = VisualObject(rgba_color=[1, 0, 0, 0.5])
            self.target_pos_vis_obj = VisualObject(rgba_color=[0, 0, 1, 0.5])
            self.initial_pos_vis_obj.load()
            if self.config.get('target_visual_object_visible_to_agent', False):
                self.simulator.import_object(self.target_pos_vis_obj)
            else:
                self.target_pos_vis_obj.load()

    def get_additional_states(self):
        relative_position = self.target_pos - self.robots[0].get_position()
        # rotate relative position back to body point of view
        additional_states = rotate_vector_3d(relative_position,
                                             *self.robots[0].get_rpy())

        if self.config['task'] == 'reaching':
            end_effector_pos = self.robots[0].get_end_effector_position(
            ) - self.robots[0].get_position()
            end_effector_pos = rotate_vector_3d(end_effector_pos,
                                                *self.robots[0].get_rpy())
            additional_states = np.concatenate(
                (additional_states, end_effector_pos))
        assert len(
            additional_states
        ) == self.additional_states_dim, 'additional states dimension mismatch'

        return additional_states
        """
        relative_position = self.target_pos - self.robots[0].get_position()
        # rotate relative position back to body point of view
        relative_position_odom = rotate_vector_3d(relative_position, *self.robots[0].get_rpy())
        # the angle between the direction the agent is facing and the direction to the target position
        delta_yaw = np.arctan2(relative_position_odom[1], relative_position_odom[0])
        additional_states = np.concatenate((relative_position,
                                            relative_position_odom,
                                            [np.sin(delta_yaw), np.cos(delta_yaw)]))
        if self.config['task'] == 'reaching':
            # get end effector information

            end_effector_pos = self.robots[0].get_end_effector_position() - self.robots[0].get_position()
            end_effector_pos = rotate_vector_3d(end_effector_pos, *self.robots[0].get_rpy())
            additional_states = np.concatenate((additional_states, end_effector_pos))

        assert len(additional_states) == self.additional_states_dim, 'additional states dimension mismatch'
        return additional_states
        """

    def get_state(self, collision_links=[]):
        # calculate state
        # sensor_state = self.robots[0].calc_state()
        # sensor_state = np.concatenate((sensor_state, self.get_additional_states()))
        sensor_state = self.get_additional_states()

        state = OrderedDict()
        if 'sensor' in self.output:
            state['sensor'] = sensor_state
        if 'pointgoal' in self.output:
            state['pointgoal'] = sensor_state[:2]
        if 'rgb' in self.output:
            state['rgb'] = self.simulator.renderer.render_robot_cameras(
                modes=('rgb'))[0][:, :, :3]
        if 'depth' in self.output:
            depth = -self.simulator.renderer.render_robot_cameras(
                modes=('3d'))[0][:, :, 2:3]
            state['depth'] = depth
        if 'normal' in self.output:
            state['normal'] = self.simulator.renderer.render_robot_cameras(
                modes='normal')
        if 'seg' in self.output:
            state['seg'] = self.simulator.renderer.render_robot_cameras(
                modes='seg')
        if 'rgb_filled' in self.output:
            with torch.no_grad():
                tensor = transforms.ToTensor()(
                    (state['rgb'] * 255).astype(np.uint8)).cuda()
                rgb_filled = self.comp(tensor[None, :, :, :])[0].permute(
                    1, 2, 0).cpu().numpy()
                state['rgb_filled'] = rgb_filled
        if 'bump' in self.output:
            state['bump'] = -1 in collision_links  # check collision for the base link (id -1); may vary across robots

        if 'pointgoal' in self.output:
            state['pointgoal'] = sensor_state[:2]

        if 'scan' in self.output:
            assert 'scan_link' in self.robots[
                0].parts, "Requested scan but no scan_link"
            pose_camera = self.robots[0].parts['scan_link'].get_pose()
            n_rays_per_horizontal = 128  # Number of rays along one horizontal scan/slice

            n_vertical_beams = 9
            angle = np.arange(0, 2 * np.pi,
                              2 * np.pi / float(n_rays_per_horizontal))
            elev_bottom_angle = -30. * np.pi / 180.
            elev_top_angle = 10. * np.pi / 180.
            elev_angle = np.arange(elev_bottom_angle, elev_top_angle,
                                   (elev_top_angle - elev_bottom_angle) /
                                   float(n_vertical_beams))
            orig_offset = np.vstack([
                np.vstack([
                    np.cos(angle),
                    np.sin(angle),
                    np.repeat(np.tan(elev_ang), angle.shape)
                ]).T for elev_ang in elev_angle
            ])
            transform_matrix = quat2mat([
                pose_camera[-1], pose_camera[3], pose_camera[4], pose_camera[5]
            ])
            offset = orig_offset.dot(np.linalg.inv(transform_matrix))
            pose_camera = pose_camera[None, :3].repeat(n_rays_per_horizontal *
                                                       n_vertical_beams,
                                                       axis=0)

            results = p.rayTestBatch(pose_camera, pose_camera + offset * 30)
            hit = np.array([item[0] for item in results])
            dist = np.array([item[2] for item in results])
            dist[dist >= 1 - 1e-5] = np.nan
            dist[dist < 0.1 / 30] = np.nan

            dist[hit == self.robots[0].robot_ids[0]] = np.nan
            dist[hit == -1] = np.nan
            dist *= 30

            xyz = dist[:, np.newaxis] * orig_offset
            xyz = xyz[np.equal(np.isnan(xyz), False)]  # Remove nans
            #print(xyz.shape)
            xyz = xyz.reshape(xyz.shape[0] // 3, -1)
            state['scan'] = xyz

        return state

    def run_simulation(self):
        collision_links = []
        for _ in range(self.simulator_loop):
            self.simulator_step()
            collision_links += [
                item[3] for item in p.getContactPoints(
                    bodyA=self.robots[0].robot_ids[0])
            ]
        collision_links = np.unique(collision_links)
        return collision_links

    def get_position_of_interest(self):
        if self.config['task'] == 'pointgoal':
            return self.robots[0].get_position()
        elif self.config['task'] == 'reaching':
            return self.robots[0].get_end_effector_position()

    def get_potential(self):
        return l2_distance(self.target_pos, self.get_position_of_interest())

    def get_reward(self, collision_links):
        reward = self.slack_reward  # |slack_reward| = 0.01 per step

        new_normalized_potential = self.get_potential(
        ) / self.initial_potential

        potential_reward = self.normalized_potential - new_normalized_potential
        reward += potential_reward * self.potential_reward_weight  # |potential_reward| ~= 0.1 per step
        self.normalized_potential = new_normalized_potential

        # electricity_reward = np.abs(self.robots[0].joint_speeds * self.robots[0].joint_torque).mean().item()
        electricity_reward = 0.0
        reward += electricity_reward * self.electricity_reward_weight  # |electricity_reward| ~= 0.05 per step

        # stall_torque_reward = np.square(self.robots[0].joint_torque).mean()
        stall_torque_reward = 0.0
        reward += stall_torque_reward * self.stall_torque_reward_weight  # |stall_torque_reward| ~= 0.05 per step

        collision_reward = -1.0 if -1 in collision_links else 0.0
        reward += collision_reward * self.collision_reward_weight  # |collision_reward| ~= 1.0 per step if collision

        # goal reached
        if l2_distance(self.target_pos,
                       self.get_position_of_interest()) < self.dist_tol:
            reward += self.success_reward  # |success_reward| = 10.0 per step

        return reward

    def get_termination(self):
        self.current_step += 1
        done, info = False, {}

        # goal reached
        if l2_distance(self.target_pos,
                       self.get_position_of_interest()) < self.dist_tol:
            # print('goal')
            done = True
            info['success'] = True
        # robot flips over
        elif self.robots[0].get_position()[2] > 0.1:
            # print('death')
            done = True
            info['success'] = False
        # time out
        elif self.current_step >= self.max_step:
            # print('timeout')
            done = True
            info['success'] = False

        return done, info

    def step(self, action):
        self.robots[0].apply_action(action)
        collision_links = self.run_simulation()
        state = self.get_state(collision_links)
        reward = self.get_reward(collision_links)
        done, info = self.get_termination()

        if done and self.automatic_reset:
            state = self.reset()
        return state, reward, done, info

    def reset_initial_and_target_pos(self):
        self.robots[0].set_position(pos=self.initial_pos)
        self.robots[0].set_orientation(
            orn=quatToXYZW(euler2quat(*self.initial_orn), 'wxyz'))

    def reset(self):
        self.robots[0].robot_specific_reset()
        self.reset_initial_and_target_pos()
        self.initial_potential = self.get_potential()
        self.normalized_potential = 1.0
        self.current_step = 0

        # set position for visual objects
        if self.visual_object_at_initial_target_pos:
            self.initial_pos_vis_obj.set_position(self.initial_pos)
            self.target_pos_vis_obj.set_position(self.target_pos)

        state = self.get_state()
        return state
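
A hedged driver sketch for the class above: construct the environment, then step it with random actions. The config file name is a placeholder.

# Minimal episode loop (config file name is a placeholder).
env = NavigateEnv(config_file='turtlebot_p2p_nav.yaml',
                  mode='headless',
                  action_timestep=1 / 10.0,
                  physics_timestep=1 / 240.0)
state = env.reset()
for _ in range(100):
    action = env.action_space.sample()
    state, reward, done, info = env.step(action)
    if done:
        state = env.reset()
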
Example #5
    def load_observation_space(self):
        """
        Load observation space
        """
        self.output = self.config['output']
        self.image_width = self.config.get('image_width', 128)
        self.image_height = self.config.get('image_height', 128)
        observation_space = OrderedDict()
        if 'sensor' in self.output:
            self.sensor_dim = self.additional_states_dim
            self.sensor_space = gym.spaces.Box(low=-np.inf,
                                               high=np.inf,
                                               shape=(self.sensor_dim,),
                                               dtype=np.float32)
            observation_space['sensor'] = self.sensor_space
        if 'rgb' in self.output:
            self.rgb_space = gym.spaces.Box(low=0.0,
                                            high=1.0,
                                            shape=(self.image_height, self.image_width, 3),
                                            dtype=np.float32)
            observation_space['rgb'] = self.rgb_space
        if 'depth' in self.output:
            self.depth_noise_rate = self.config.get('depth_noise_rate', 0.0)
            self.depth_low = self.config.get('depth_low', 0.5)
            self.depth_high = self.config.get('depth_high', 5.0)
            self.depth_space = gym.spaces.Box(low=0.0,
                                              high=1.0,
                                              shape=(self.image_height, self.image_width, 1),
                                              dtype=np.float32)
            observation_space['depth'] = self.depth_space
        if 'rgbd' in self.output:
            self.rgbd_space = gym.spaces.Box(low=0.0,
                                             high=1.0,
                                             shape=(self.image_height, self.image_width, 4),
                                             dtype=np.float32)
            observation_space['rgbd'] = self.rgbd_space
        if 'seg' in self.output:
            self.seg_space = gym.spaces.Box(low=0.0,
                                            high=1.0,
                                            shape=(self.image_height, self.image_width, 1),
                                            dtype=np.float32)
            observation_space['seg'] = self.seg_space
        if 'scan' in self.output:
            self.scan_noise_rate = self.config.get('scan_noise_rate', 0.0)
            self.n_horizontal_rays = self.config.get('n_horizontal_rays', 128)
            self.n_vertical_beams = self.config.get('n_vertical_beams', 1)
            assert self.n_vertical_beams == 1, 'scan can only handle one vertical beam for now'
            self.laser_linear_range = self.config.get('laser_linear_range', 10.0)
            self.laser_angular_range = self.config.get('laser_angular_range', 180.0)
            self.min_laser_dist = self.config.get('min_laser_dist', 0.05)
            self.laser_link_name = self.config.get('laser_link_name', 'scan_link')
            self.scan_space = gym.spaces.Box(low=0.0,
                                             high=1.0,
                                             shape=(self.n_horizontal_rays * self.n_vertical_beams, 1),
                                             dtype=np.float32)
            observation_space['scan'] = self.scan_space
        if 'rgb_filled' in self.output:  # use filler
            try:
                import torch.nn as nn
                import torch
                from torchvision import datasets, transforms
                from gibson2.learn.completion import CompletionNet
            except ImportError:
                raise Exception('Trying to use rgb_filled ("the goggle"), but torch is not installed. Try "pip install torch torchvision".')

            self.comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
            self.comp = torch.nn.DataParallel(self.comp).cuda()
            self.comp.load_state_dict(
                torch.load(os.path.join(gibson2.assets_path, 'networks', 'model.pth')))
            self.comp.eval()

        self.observation_space = gym.spaces.Dict(observation_space)
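
Once load_observation_space() has run, the assembled gym.spaces.Dict can be inspected or sampled; a small sketch, assuming env is an environment instance whose load() has already completed.

# List the active modalities and their declared shapes.
for name, space in env.observation_space.spaces.items():
    print(name, space.shape, space.dtype)

# Draw a random observation with the same structure as get_state() returns.
dummy_obs = env.observation_space.sample()
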
Example #6
class NavigateEnv(BaseEnv):
    """
    We define navigation environments following Anderson, Peter, et al. 'On evaluation of embodied navigation agents.'
    arXiv preprint arXiv:1807.06757 (2018). (https://arxiv.org/pdf/1807.06757.pdf)

    """
    def __init__(
            self,
            config_file,
            model_id=None,
            mode='headless',
            action_timestep=1 / 10.0,
            physics_timestep=1 / 240.0,
            automatic_reset=False,
            device_idx=0,
            render_to_tensor=False
    ):
        """
        :param config_file: config_file path
        :param model_id: override model_id in config file
        :param mode: headless or gui mode
        :param action_timestep: the environment executes one action every action_timestep seconds
        :param physics_timestep: physics timestep for pybullet
        :param automatic_reset: whether to automatically reset after an episode finishes
        :param device_idx: which GPU to run the simulation and rendering on
        :param render_to_tensor: whether to render directly to a GPU tensor
        """
        super(NavigateEnv, self).__init__(config_file=config_file,
                                          model_id=model_id,
                                          mode=mode,
                                          action_timestep=action_timestep,
                                          physics_timestep=physics_timestep,
                                          device_idx=device_idx,
                                          render_to_tensor=render_to_tensor)
        self.automatic_reset = automatic_reset

    def load_task_setup(self):
        """
        Load task setup, including initialization, termination condition, reward, collision checking and discount factor
        """
        # initial and target pose
        self.initial_pos = np.array(self.config.get('initial_pos', [0, 0, 0]))
        self.initial_orn = np.array(self.config.get('initial_orn', [0, 0, 0]))
        self.target_pos = np.array(self.config.get('target_pos', [5, 5, 0]))
        self.target_orn = np.array(self.config.get('target_orn', [0, 0, 0]))

        self.initial_pos_z_offset = self.config.get('initial_pos_z_offset', 0.1)
        check_collision_distance = self.initial_pos_z_offset * 0.5
        # s = 0.5 * G * (t ** 2)
        check_collision_distance_time = np.sqrt(check_collision_distance / (0.5 * 9.8))
        self.check_collision_loop = int(check_collision_distance_time / self.physics_timestep)

        self.additional_states_dim = self.config.get('additional_states_dim', 0)
        self.goal_format = self.config.get('goal_format', 'polar')

        # termination condition
        self.dist_tol = self.config.get('dist_tol', 0.5)
        self.max_step = self.config.get('max_step', 500)
        self.max_collisions_allowed = self.config.get('max_collisions_allowed', 500)

        # reward
        self.reward_type = self.config.get('reward_type', 'l2')
        assert self.reward_type in ['geodesic', 'l2', 'sparse']

        self.success_reward = self.config.get('success_reward', 10.0)
        self.slack_reward = self.config.get('slack_reward', -0.01)

        # reward weight
        self.potential_reward_weight = self.config.get('potential_reward_weight', 1.0)
        self.collision_reward_weight = self.config.get('collision_reward_weight', -0.1)

        # ignore the agent's collision with these body ids
        self.collision_ignore_body_b_ids = set(self.config.get('collision_ignore_body_b_ids', []))
        # ignore the agent's collision with these link ids of itself
        self.collision_ignore_link_a_ids = set(self.config.get('collision_ignore_link_a_ids', []))

        # discount factor
        self.discount_factor = self.config.get('discount_factor', 0.99)

    def load_observation_space(self):
        """
        Load observation space
        """
        self.output = self.config['output']
        self.image_width = self.config.get('image_width', 128)
        self.image_height = self.config.get('image_height', 128)
        observation_space = OrderedDict()
        if 'sensor' in self.output:
            self.sensor_dim = self.additional_states_dim
            self.sensor_space = gym.spaces.Box(low=-np.inf,
                                               high=np.inf,
                                               shape=(self.sensor_dim,),
                                               dtype=np.float32)
            observation_space['sensor'] = self.sensor_space
        if 'rgb' in self.output:
            self.rgb_space = gym.spaces.Box(low=0.0,
                                            high=1.0,
                                            shape=(self.image_height, self.image_width, 3),
                                            dtype=np.float32)
            observation_space['rgb'] = self.rgb_space
        if 'depth' in self.output:
            self.depth_noise_rate = self.config.get('depth_noise_rate', 0.0)
            self.depth_low = self.config.get('depth_low', 0.5)
            self.depth_high = self.config.get('depth_high', 5.0)
            self.depth_space = gym.spaces.Box(low=0.0,
                                              high=1.0,
                                              shape=(self.image_height, self.image_width, 1),
                                              dtype=np.float32)
            observation_space['depth'] = self.depth_space
        if 'rgbd' in self.output:
            self.rgbd_space = gym.spaces.Box(low=0.0,
                                             high=1.0,
                                             shape=(self.image_height, self.image_width, 4),
                                             dtype=np.float32)
            observation_space['rgbd'] = self.rgbd_space
        if 'seg' in self.output:
            self.seg_space = gym.spaces.Box(low=0.0,
                                            high=1.0,
                                            shape=(self.image_height, self.image_width, 1),
                                            dtype=np.float32)
            observation_space['seg'] = self.seg_space
        if 'scan' in self.output:
            self.scan_noise_rate = self.config.get('scan_noise_rate', 0.0)
            self.n_horizontal_rays = self.config.get('n_horizontal_rays', 128)
            self.n_vertical_beams = self.config.get('n_vertical_beams', 1)
            assert self.n_vertical_beams == 1, 'scan can only handle one vertical beam for now'
            self.laser_linear_range = self.config.get('laser_linear_range', 10.0)
            self.laser_angular_range = self.config.get('laser_angular_range', 180.0)
            self.min_laser_dist = self.config.get('min_laser_dist', 0.05)
            self.laser_link_name = self.config.get('laser_link_name', 'scan_link')
            self.scan_space = gym.spaces.Box(low=0.0,
                                             high=1.0,
                                             shape=(self.n_horizontal_rays * self.n_vertical_beams, 1),
                                             dtype=np.float32)
            observation_space['scan'] = self.scan_space
        if 'rgb_filled' in self.output:  # use filler
            try:
                import torch.nn as nn
                import torch
                from torchvision import datasets, transforms
                from gibson2.learn.completion import CompletionNet
            except ImportError:
                raise Exception('Trying to use rgb_filled ("the goggle"), but torch is not installed. Try "pip install torch torchvision".')

            self.comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
            self.comp = torch.nn.DataParallel(self.comp).cuda()
            self.comp.load_state_dict(
                torch.load(os.path.join(gibson2.assets_path, 'networks', 'model.pth')))
            self.comp.eval()

        self.observation_space = gym.spaces.Dict(observation_space)

    def load_action_space(self):
        """
        Load action space
        """
        self.action_space = self.robots[0].action_space

    def load_visualization(self):
        """
        Load visualization, such as markers for the initial and target positions and the shortest path
        """
        if self.mode != 'gui':
            return

        cyl_length = 0.2
        self.initial_pos_vis_obj = VisualMarker(visual_shape=p.GEOM_CYLINDER,
                                                rgba_color=[1, 0, 0, 0.3],
                                                radius=self.dist_tol,
                                                length=cyl_length,
                                                initial_offset=[0, 0, cyl_length / 2.0])
        self.target_pos_vis_obj = VisualMarker(visual_shape=p.GEOM_CYLINDER,
                                               rgba_color=[0, 0, 1, 0.3],
                                               radius=self.dist_tol,
                                               length=cyl_length,
                                               initial_offset=[0, 0, cyl_length / 2.0])
        self.initial_pos_vis_obj.load()
        self.target_pos_vis_obj.load()

        if self.scene.build_graph:
            self.num_waypoints_vis = 250
            self.waypoints_vis = [VisualMarker(visual_shape=p.GEOM_CYLINDER,
                                               rgba_color=[0, 1, 0, 0.3],
                                               radius=0.1,
                                               length=cyl_length,
                                               initial_offset=[0, 0, cyl_length / 2.0])
                                  for _ in range(self.num_waypoints_vis)]
            for waypoint in self.waypoints_vis:
                waypoint.load()

    def load_miscellaneous_variables(self):
        """
        Load miscellaneous variables for bookkeeping
        """
        self.current_step = 0
        self.collision_step = 0
        self.current_episode = 0
        self.floor_num = 0

    def load(self):
        """
        Load navigation environment
        """
        super(NavigateEnv, self).load()
        self.load_task_setup()
        self.load_observation_space()
        self.load_action_space()
        self.load_visualization()
        self.load_miscellaneous_variables()

    def global_to_local(self, pos):
        """
        Convert a 3D point in global frame to agent's local frame
        :param pos: a 3D point in global frame
        :return: the same 3D point in agent's local frame
        """
        return rotate_vector_3d(pos - self.robots[0].get_position(), *self.robots[0].get_rpy())

    def get_additional_states(self):
        """
        :return: non-perception observation, such as goal location
        """
        additional_states = self.global_to_local(self.target_pos)[:2]
        if self.goal_format == 'polar':
            additional_states = np.array(cartesian_to_polar(additional_states[0], additional_states[1]))

        # linear velocity along the x-axis
        linear_velocity = rotate_vector_3d(self.robots[0].get_linear_velocity(),
                                           *self.robots[0].get_rpy())[0]
        # angular velocity along the z-axis
        angular_velocity = rotate_vector_3d(self.robots[0].get_angular_velocity(),
                                            *self.robots[0].get_rpy())[2]
        additional_states = np.append(additional_states, [linear_velocity, angular_velocity])

        if self.config['task'] == 'reaching':
            end_effector_pos_local = self.global_to_local(self.robots[0].get_end_effector_position())
            additional_states = np.append(additional_states, end_effector_pos_local)

        assert additional_states.shape[0] == self.additional_states_dim, \
            'additional states dimension mismatch {} v.s. {}'.format(additional_states.shape[0], self.additional_states_dim)
        return additional_states

    def add_naive_noise_to_sensor(self, sensor_reading, noise_rate, noise_value=1.0):
        """
        Add naive sensor dropout to perceptual sensor, such as RGBD and LiDAR scan
        :param sensor_reading: raw sensor reading, range must be between [0.0, 1.0]
        :param noise_rate: how much noise to inject, 0.05 means 5% of the data will be replaced with noise_value
        :param noise_value: noise_value to overwrite raw sensor reading
        :return: sensor reading corrupted with noise
        """
        if noise_rate <= 0.0:
            return sensor_reading

        assert len(sensor_reading[(sensor_reading < 0.0) | (sensor_reading > 1.0)]) == 0,\
            'sensor reading has to be between [0.0, 1.0]'

        valid_mask = np.random.choice(2, sensor_reading.shape, p=[noise_rate, 1.0 - noise_rate])
        sensor_reading[valid_mask == 0] = noise_value
        return sensor_reading

    def get_depth(self):
        """
        :return: depth sensor reading, normalized to [0.0, 1.0]
        """
        depth = -self.simulator.renderer.render_robot_cameras(modes=('3d'))[0][:, :, 2:3]
        # 0.0 is a special value for invalid entries
        depth[depth < self.depth_low] = 0.0
        depth[depth > self.depth_high] = 0.0

        # re-scale depth to [0.0, 1.0]
        depth /= self.depth_high
        depth = self.add_naive_noise_to_sensor(depth, self.depth_noise_rate, noise_value=0.0)

        return depth

    def get_rgb(self):
        """
        :return: RGB sensor reading, normalized to [0.0, 1.0]
        """
        return self.simulator.renderer.render_robot_cameras(modes=('rgb'))[0][:, :, :3]

    def get_pc(self):
        """
        :return: pointcloud sensor reading
        """
        return self.simulator.renderer.render_robot_cameras(modes=('3d'))[0]

    def get_normal(self):
        """
        :return: surface normal reading
        """
        return self.simulator.renderer.render_robot_cameras(modes='normal')

    def get_seg(self):
        """
        :return: semantic segmentation mask, normalized to [0.0, 1.0]
        """
        seg = self.simulator.renderer.render_robot_cameras(modes='seg')[0][:, :, 0:1]
        if self.num_object_classes is not None:
            seg = np.clip(seg * 255.0 / self.num_object_classes, 0.0, 1.0)
        return seg

    def get_scan(self):
        """
        :return: LiDAR sensor reading, normalized to [0.0, 1.0]
        """
        laser_angular_half_range = self.laser_angular_range / 2.0
        if self.laser_link_name not in self.robots[0].parts:
            raise Exception('Trying to simulate LiDAR sensor, but laser_link_name cannot be found in the robot URDF file. Please add a link named laser_link_name at the intended laser pose. Feel free to check out assets/models/turtlebot/turtlebot.urdf and examples/configs/turtlebot_p2p_nav.yaml for examples.')
        laser_pose = self.robots[0].parts[self.laser_link_name].get_pose()
        angle = np.arange(-laser_angular_half_range / 180 * np.pi,
                          laser_angular_half_range / 180 * np.pi,
                          self.laser_angular_range / 180.0 * np.pi / self.n_horizontal_rays)
        unit_vector_local = np.array([[np.cos(ang), np.sin(ang), 0.0] for ang in angle])
        transform_matrix = quat2mat([laser_pose[6], laser_pose[3], laser_pose[4], laser_pose[5]])  # [x, y, z, w]
        unit_vector_world = transform_matrix.dot(unit_vector_local.T).T

        start_pose = np.tile(laser_pose[:3], (self.n_horizontal_rays, 1))
        start_pose += unit_vector_world * self.min_laser_dist
        end_pose = laser_pose[:3] + unit_vector_world * self.laser_linear_range
        results = p.rayTestBatch(start_pose, end_pose, 6)  # numThreads = 6

        hit_fraction = np.array([item[2] for item in results])  # hit fraction = [0.0, 1.0] of self.laser_linear_range
        hit_fraction = self.add_naive_noise_to_sensor(hit_fraction, self.scan_noise_rate)
        scan = np.expand_dims(hit_fraction, 1)

        xyz = hit_fraction[:, np.newaxis] * unit_vector_local * 10
        xyz = xyz[np.equal(np.isnan(xyz), False)]
        xyz = xyz.reshape(xyz.shape[0] // 3, -1)
        return xyz  # returns XYZ points; the normalized 'scan' computed above is currently unused

    def get_state(self, collision_links=[]):
        """
        :param collision_links: collisions from last time step
        :return: observation as a dictionary
        """
        state = OrderedDict()
        if 'sensor' in self.output:
            state['sensor'] = self.get_additional_states()
        if 'rgb' in self.output:
            state['rgb'] = self.get_rgb()
        if 'depth' in self.output:
            state['depth'] = self.get_depth()
        if 'pc' in self.output:
            state['pc'] = self.get_pc()
        if 'rgbd' in self.output:
            rgb = self.get_rgb()
            depth = self.get_depth()
            state['rgbd'] = np.concatenate((rgb, depth), axis=2)
        if 'normal' in self.output:
            state['normal'] = self.get_normal()
        if 'seg' in self.output:
            state['seg'] = self.get_seg()
        if 'rgb_filled' in self.output:
            with torch.no_grad():
                tensor = transforms.ToTensor()((state['rgb'] * 255).astype(np.uint8)).cuda()
                rgb_filled = self.comp(tensor[None, :, :, :])[0].permute(1, 2, 0).cpu().numpy()
                state['rgb_filled'] = rgb_filled
        if 'scan' in self.output:
            state['scan'] = self.get_scan()
        return state

    def run_simulation(self):
        """
        Run simulation for one action timestep (simulator_loop physics timesteps)
        :return: collisions from this simulation
        """
        collision_links = []
        for _ in range(self.simulator_loop):
            self.simulator_step()
            collision_links.append(list(p.getContactPoints(bodyA=self.robots[0].robot_ids[0])))
        self.simulator.sync()

        return self.filter_collision_links(collision_links)

    def filter_collision_links(self, collision_links):
        """
        Filter out collisions that should be ignored
        :param collision_links: original collisions, a list of lists of collisions
        :return: filtered collisions
        """
        new_collision_links = []
        for collision_per_sim_step in collision_links:
            new_collision_per_sim_step = []
            for item in collision_per_sim_step:
                # ignore collision with body b
                if item[2] in self.collision_ignore_body_b_ids:
                    continue

                # ignore collision with robot link a
                if item[3] in self.collision_ignore_link_a_ids:
                    continue

                # ignore self collision with robot link a (body b is also robot itself)
                if item[2] == self.robots[0].robot_ids[0] and item[4] in self.collision_ignore_link_a_ids:
                    continue

                new_collision_per_sim_step.append(item)
            new_collision_links.append(new_collision_per_sim_step)
        return new_collision_links

    def get_position_of_interest(self):
        """
        Get position of interest.
        :return: If pointgoal task, return base position. If reaching task, return end effector position.
        """
        if self.config['task'] == 'pointgoal':
            return self.robots[0].get_position()
        elif self.config['task'] == 'reaching':
            return self.robots[0].get_end_effector_position()

    def get_shortest_path(self, from_initial_pos=False, entire_path=False):
        """
        :param from_initial_pos: whether source is initial position rather than current position
        :param entire_path: whether to return the entire shortest path
        :return: shortest path and geodesic distance to the target position
        """
        if from_initial_pos:
            source = self.initial_pos[:2]
        else:
            source = self.robots[0].get_position()[:2]
        target = self.target_pos[:2]
        return self.scene.get_shortest_path(self.floor_num, source, target, entire_path=entire_path)

    def get_geodesic_potential(self):
        """
        :return: geodesic distance to the target position
        """
        _, geodesic_dist = self.get_shortest_path()
        return geodesic_dist

    def get_l2_potential(self):
        """
        :return: L2 distance to the target position
        """
        return l2_distance(self.target_pos, self.get_position_of_interest())

    def is_goal_reached(self):
        return l2_distance(self.get_position_of_interest(), self.target_pos) < self.dist_tol

    def get_reward(self, collision_links=[], action=None, info={}):
        """
        :param collision_links: collisions from last time step
        :param action: last action
        :param info: a dictionary to store additional info
        :return: reward, info
        """
        collision_links_flatten = [item for sublist in collision_links for item in sublist]
        reward = self.slack_reward  # |slack_reward| = 0.01 per step

        if self.reward_type == 'l2':
            new_potential = self.get_l2_potential()
        elif self.reward_type == 'geodesic':
            new_potential = self.get_geodesic_potential()
        potential_reward = self.potential - new_potential
        reward += potential_reward * self.potential_reward_weight  # |potential_reward| ~= 0.1 per step
        self.potential = new_potential

        collision_reward = float(len(collision_links_flatten) > 0)
        self.collision_step += int(collision_reward)
        reward += collision_reward * self.collision_reward_weight  # |collision_reward| ~= 1.0 per step if collision

        if self.is_goal_reached():
            reward += self.success_reward  # |success_reward| = 10.0 per step
        return reward, info

    def get_termination(self, collision_links=[], action=None, info={}):
        """
        :param collision_links: collisions from last time step
        :param info: a dictionary to store additional info
        :return: done, info
        """
        done = False

        # goal reached
        if self.is_goal_reached():
            done = True
            info['success'] = True

        # max collisions reached
        if self.collision_step > self.max_collisions_allowed:
            done = True
            info['success'] = False

        # time out
        elif self.current_step >= self.max_step:
            done = True
            info['success'] = False

        if done:
            info['episode_length'] = self.current_step
            info['collision_step'] = self.collision_step
            info['path_length'] = self.path_length
            info['spl'] = float(info['success']) * min(1.0, self.geodesic_dist / self.path_length)

        return done, info

    def before_simulation(self):
        """
        Cache bookkeeping data before simulation
        :return: cache
        """
        return {'robot_position': self.robots[0].get_position()}

    def after_simulation(self, cache, collision_links):
        """
        Accumulate evaluation stats
        :param cache: cache returned from before_simulation
        :param collision_links: collisions from last time step
        """
        old_robot_position = cache['robot_position'][:2]
        new_robot_position = self.robots[0].get_position()[:2]
        self.path_length += l2_distance(old_robot_position, new_robot_position)

    def step_visualization(self):
        if self.mode != 'gui':
            return

        self.initial_pos_vis_obj.set_position(self.initial_pos)
        self.target_pos_vis_obj.set_position(self.target_pos)

        if self.scene.build_graph:
            shortest_path, _ = self.get_shortest_path(entire_path=True)
            floor_height = 0.0 if self.floor_num is None else self.scene.get_floor_height(self.floor_num)
            num_nodes = min(self.num_waypoints_vis, shortest_path.shape[0])
            for i in range(num_nodes):
                self.waypoints_vis[i].set_position(pos=np.array([shortest_path[i][0],
                                                                 shortest_path[i][1],
                                                                 floor_height]))
            for i in range(num_nodes, self.num_waypoints_vis):
                self.waypoints_vis[i].set_position(pos=np.array([0.0, 0.0, 100.0]))

    def step(self, action):
        """
        Apply the robot's action and return state, reward, done and info, following the OpenAI Gym convention
        :param action: a list of control signals
        :return: state, reward, done, info
        """
        self.current_step += 1
        if action is not None:
            self.robots[0].apply_action(action)
        cache = self.before_simulation()
        collision_links = self.run_simulation()
        self.after_simulation(cache, collision_links)

        state = self.get_state(collision_links)
        info = {}
        reward, info = self.get_reward(collision_links, action, info)
        done, info = self.get_termination(collision_links, action, info)
        self.step_visualization()

        if done and self.automatic_reset:
            info['last_observation'] = state
            state = self.reset()
        return state, reward, done, info

    def reset_agent(self):
        """
        Reset the robot's joint configuration and base pose until no collision
        """
        reset_success = False
        max_trials = 100
        for _ in range(max_trials):
            self.reset_initial_and_target_pos()
            if self.test_valid_position('robot', self.robots[0], self.initial_pos, self.initial_orn) and \
                    self.test_valid_position('robot', self.robots[0], self.target_pos):
                reset_success = True
                break

        if not reset_success:
            logging.warning("WARNING: Failed to reset robot without collision")

        self.land('robot', self.robots[0], self.initial_pos, self.initial_orn)

    def reset_initial_and_target_pos(self):
        """
        Reset initial_pos, initial_orn and target_pos
        """
        return

    def check_collision(self, body_id):
        """
        :param body_id: pybullet body id
        :return: whether the given body_id has no collision
        """
        for _ in range(self.check_collision_loop):
            self.simulator_step()
            collisions = list(p.getContactPoints(bodyA=body_id))

            if logging.root.level <= logging.DEBUG:  # only iterate over collisions when debug logging is enabled
                for item in collisions:
                    logging.debug('bodyA:{}, bodyB:{}, linkA:{}, linkB:{}'.format(item[1], item[2], item[3], item[4]))

            if len(collisions) > 0:
                return False
        return True

    def set_pos_orn_with_z_offset(self, obj, pos, orn=None, offset=None):
        """
        Reset position and orientation for the robot or the object
        :param obj: an instance of robot or object
        :param pos: position
        :param orn: orientation
        :param offset: z offset
        """
        if orn is None:
            orn = np.array([0, 0, np.random.uniform(0, np.pi * 2)])

        if offset is None:
            offset = self.initial_pos_z_offset

        obj.set_position_orientation([pos[0], pos[1], pos[2] + offset],
                                     quatToXYZW(euler2quat(*orn), 'wxyz'))

    def test_valid_position(self, obj_type, obj, pos, orn=None):
        """
        Test if the robot or the object can be placed with no collision
        :param obj_type: string "robot" or "obj"
        :param obj: an instance of robot or object
        :param pos: position
        :param orn: orientation
        :return: validity
        """
        assert obj_type in ['robot', 'obj']

        self.set_pos_orn_with_z_offset(obj, pos, orn)

        if obj_type == 'robot':
            obj.robot_specific_reset()
            obj.keep_still()

        body_id = obj.robot_ids[0] if obj_type == 'robot' else obj.body_id
        return self.check_collision(body_id)

    def land(self, obj_type, obj, pos, orn):
        """
        Land the robot or the object onto the floor, given a valid position and orientation
        :param obj_type: string "robot" or "obj"
        :param obj: an instance of robot or object
        :param pos: position
        :param orn: orientation
        """
        assert obj_type in ['robot', 'obj']

        self.set_pos_orn_with_z_offset(obj, pos, orn)

        if obj_type == 'robot':
            obj.robot_specific_reset()
            obj.keep_still()

        body_id = obj.robot_ids[0] if obj_type == 'robot' else obj.body_id

        land_success = False
        # land for maximum 1 second, should fall down ~5 meters
        max_simulator_step = int(1.0 / self.physics_timestep)
        for _ in range(max_simulator_step):
            self.simulator_step()
            if len(p.getContactPoints(bodyA=body_id)) > 0:
                land_success = True
                break

        if not land_success:
            print("WARNING: Failed to land")

        if obj_type == 'robot':
            obj.robot_specific_reset()

    def reset_variables(self):
        """
        Reset bookkeeping variables for the next new episode
        """
        self.current_episode += 1
        self.current_step = 0
        self.collision_step = 0
        self.path_length = 0.0
        self.geodesic_dist = self.get_geodesic_potential()

    def reset(self):
        """
        Reset episode
        """
        self.reset_agent()
        self.simulator.sync()
        state = self.get_state()
        if self.reward_type == 'l2':
            self.potential = self.get_l2_potential()
        elif self.reward_type == 'geodesic':
            self.potential = self.get_geodesic_potential()
        self.reset_variables()
        self.step_visualization()

        return state
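
A minimal usage sketch for the reset()/step() interface above, following the OpenAI gym convention. The driver loop below is hypothetical: it assumes `env` is an already-constructed instance of the navigation environment shown above and that it exposes a standard gym-style `action_space`.

    # hypothetical driver loop; `env` is assumed to be built elsewhere from a config file
    state = env.reset()
    for _ in range(1000):
        action = env.action_space.sample()  # random policy, for illustration only
        state, reward, done, info = env.step(action)
        if done:
            # with automatic_reset enabled, step() already returns the first
            # observation of the new episode and keeps the final one in
            # info['last_observation']; otherwise, reset manually:
            if not env.automatic_reset:
                state = env.reset()
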
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--imgsize', type=int, default=256, help='image size')
    parser.add_argument('--nf', type=int, default=64, help='number of filters')
    parser.add_argument('--batchsize', type=int, default=20, help='batchsize')
    parser.add_argument('--workers',
                        type=int,
                        default=9,
                        help='number of workers')
    parser.add_argument('--nepoch',
                        type=int,
                        default=50,
                        help='number of epochs')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-5,
                        help='learning rate, default=2e-5')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.5,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--outf',
                        type=str,
                        default="filler_pano_pc_full",
                        help='output folder')
    parser.add_argument('--model', type=str, default="", help='model path')
    parser.add_argument('--cepoch', type=int, default=0, help='current epoch')
    parser.add_argument('--loss',
                        type=str,
                        default="perceptual",
                        help="loss type: 'train_init', 'l1', 'perceptual', 'color_stable' or 'color_correction'")
    parser.add_argument('--init', type=str, default="iden", help='init method')
    parser.add_argument('--l1', type=float, default=0, help='add l1 loss')
    parser.add_argument('--color_coeff',
                        type=float,
                        default=0,
                        help='add color match loss')
    parser.add_argument('--unfiller',
                        action='store_true',
                        help='additionally train an unfiller network')
    parser.add_argument('--joint',
                        action='store_true',
                        help='train the filler and unfiller jointly')
    parser.add_argument('--use_depth',
                        action='store_true',
                        default=False,
                        help='use the depth channel')
    parser.add_argument('--zoom',
                        type=int,
                        default=1,
                        help='zoom factor for cropping')
    parser.add_argument('--patchsize',
                        type=int,
                        default=256,
                        help='size of the cropped training patch')

    mean = torch.from_numpy(
        np.array([0.57441127, 0.54226291,
                  0.50356019]).astype(np.float32)).clone()
    opt = parser.parse_args()
    print(opt)
    writer = SummaryWriter(opt.outf + '/runs/' +
                           datetime.now().strftime('%B%d  %H:%M:%S'))
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    zoom = opt.zoom
    patchsize = opt.patchsize

    tf = transforms.Compose([
        transforms.ToTensor(),
    ])

    mist_tf = transforms.Compose([
        transforms.ToTensor(),
    ])

    d = PairDataset(root=opt.dataroot, transform=tf, mist_transform=mist_tf)
    d_test = PairDataset(root=opt.dataroot,
                         transform=tf,
                         mist_transform=mist_tf,
                         train=False)

    cudnn.benchmark = True

    dataloader = torch.utils.data.DataLoader(d,
                                             batch_size=opt.batchsize,
                                             shuffle=True,
                                             num_workers=int(opt.workers),
                                             drop_last=True,
                                             pin_memory=False)
    dataloader_test = torch.utils.data.DataLoader(d_test,
                                                  batch_size=opt.batchsize,
                                                  shuffle=True,
                                                  num_workers=int(opt.workers),
                                                  drop_last=True,
                                                  pin_memory=False)

    img = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()
    maskv = Variable(torch.zeros(opt.batchsize, 2, 1024, 2048)).cuda()
    img_original = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()
    label = Variable(torch.LongTensor(opt.batchsize * 4)).cuda()

    comp = CompletionNet(norm=nn.BatchNorm2d, nf=opt.nf)

    current_epoch = opt.cepoch

    comp = torch.nn.DataParallel(comp).cuda()

    if opt.init == 'iden':
        comp.apply(identity_init)
    else:
        comp.apply(weights_init)

    if opt.model != '':
        comp.load_state_dict(torch.load(opt.model))
        # dis.load_state_dict(torch.load(opt.model.replace("G", "D")))
        current_epoch = opt.cepoch

    if opt.unfiller:
        comp2 = CompletionNet(norm=nn.BatchNorm2d, nf=64)
        comp2 = torch.nn.DataParallel(comp2).cuda()
        if opt.model != '':
            comp2.load_state_dict(torch.load(opt.model.replace('G', 'G2')))
        optimizerG2 = torch.optim.Adam(comp2.parameters(),
                                       lr=opt.lr,
                                       betas=(opt.beta1, 0.999))

    l2 = nn.MSELoss()
    # if opt.loss == 'train_init':
    #    params = list(comp.parameters())
    #    sel = np.random.choice(len(params), len(params)/2, replace=False)
    #    params_sel = [params[i] for i in sel]
    #    optimizerG = torch.optim.Adam(params_sel, lr = opt.lr, betas = (opt.beta1, 0.999))
    #
    # else:
    optimizerG = torch.optim.Adam(comp.parameters(),
                                  lr=opt.lr,
                                  betas=(opt.beta1, 0.999))

    curriculum = (
        200000, 300000
    )  # step to start D training and G training, slightly different from the paper
    alpha = 0.004

    errG_data = 0
    errD_data = 0

    vgg16 = models.vgg16(pretrained=False)
    vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
    feat = vgg16.features
    p = torch.nn.DataParallel(Perceptual(feat)).cuda()

    for param in p.parameters():
        param.requires_grad = False

    imgnet_mean = torch.from_numpy(
        np.array([0.485, 0.456, 0.406]).astype(np.float32)).clone()
    imgnet_std = torch.from_numpy(
        np.array([0.229, 0.224, 0.225]).astype(np.float32)).clone()

    imgnet_mean_img = Variable(
        imgnet_mean.view(1, 3, 1, 1).repeat(opt.batchsize * 4, 1, patchsize,
                                            patchsize)).cuda()
    imgnet_std_img = Variable(
        imgnet_std.view(1, 3, 1, 1).repeat(opt.batchsize * 4, 1, patchsize,
                                           patchsize)).cuda()
    test_loader_enum = enumerate(dataloader_test)

    for epoch in range(current_epoch, opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            optimizerG.zero_grad()
            source = data[0]
            source_depth = data[1]
            target = data[2]
            step = i + epoch * len(dataloader)

            mask = (torch.sum(source[:, :3, :, :], 1) > 0).float().unsqueeze(1)
            # img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2),2) / torch.sum(torch.sum(mask, 2),2).view(opt.batchsize,1)

            source[:, :3, :, :] += (1 - mask.repeat(1, 3, 1, 1)) * mean.view(
                1, 3, 1, 1).repeat(opt.batchsize, 1, 1024, 2048)
            source_depth = source_depth[:, :, :, 0].unsqueeze(1)
            # print(source_depth.size(), mask.size())
            source_depth = torch.cat([source_depth, mask], 1)
            img.data.copy_(source)
            maskv.data.copy_(source_depth)
            img_original.data.copy_(target)
            imgc, maskvc, img_originalc = crop(img, maskv, img_original, zoom,
                                               patchsize)
            # from IPython import embed; embed()
            recon = comp(imgc, maskvc)

            if opt.loss == "train_init":
                loss = l2(recon, imgc[:, :3, :, :])
            elif opt.loss == 'l1':
                loss = l2(recon, img_originalc)
            elif opt.loss == 'perceptual':
                loss = l2(p(recon),
                          p(img_originalc).detach()) + opt.l1 * l2(
                              recon, img_originalc)
            elif opt.loss == 'color_stable':
                loss = l2(
                    p(
                        recon.view(recon.size(0) * 3, 1, patchsize,
                                   patchsize).repeat(1, 3, 1, 1)),
                    p(
                        img_originalc.view(
                            img_originalc.size(0) * 3, 1, patchsize,
                            patchsize).repeat(1, 3, 1, 1)).detach())
            elif opt.loss == 'color_correction':
                recon_percept = p((recon - imgnet_mean_img) / imgnet_std_img)
                org_percept = p((img_originalc - imgnet_mean_img) /
                                (imgnet_std_img)).detach()
                loss = l2(recon_percept, org_percept)
                for scale in [32]:
                    img_originalc_patch = img_originalc.view(
                        opt.batchsize * 4, 3, patchsize // scale, scale,
                        patchsize // scale,
                        scale).transpose(4, 3).contiguous().view(
                            opt.batchsize * 4, 3, patchsize // scale,
                            patchsize // scale, -1)
                    recon_patch = recon.view(
                        opt.batchsize * 4, 3, patchsize // scale, scale,
                        patchsize // scale,
                        scale).transpose(4, 3).contiguous().view(
                            opt.batchsize * 4, 3, patchsize // scale,
                            patchsize // scale, -1)
                    img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
                    recon_patch_mean = recon_patch.mean(dim=-1)
                    # recon_patch_cov = []
                    # img_originalc_patch_cov = []

                    # for j in range(3):
                    #    recon_patch_cov.append((recon_patch * recon_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
                    #    img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))

                    # recon_patch_cov_cat = torch.cat(recon_patch_cov,1)
                    # img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)

                    color_loss = l2(
                        recon_patch_mean, img_originalc_patch_mean
                    )  # + l2(recon_patch_cov_cat, img_originalc_patch_cov_cat.detach())

                    loss += opt.color_coeff * color_loss

                    print("color loss %f" % color_loss.data[0])

            loss.backward(retain_graph=True)

            if opt.unfiller:
                optimizerG2.zero_grad()

                recon2 = comp2(img_originalc, maskvc)

                if not opt.joint:
                    recon2_percept = p(
                        (recon2 - imgnet_mean_img) / imgnet_std_img)
                    recon_percept = p(
                        (recon - imgnet_mean_img) / imgnet_std_img)
                    loss2 = l2(recon2_percept, recon_percept.detach())
                else:
                    recon_percept = p(
                        (recon - imgnet_mean_img) / imgnet_std_img)
                    z = Variable(torch.zeros(recon_percept.size()).cuda())
                    recon2_percept = p(
                        (recon2 - imgnet_mean_img) / imgnet_std_img)

                    loss2 = l2(recon2_percept - recon_percept, z)

                    loss2 += 0.2 * l2(recon2_percept, org_percept)

                for scale in [32]:
                    img_originalc_patch = recon.detach().view(
                        opt.batchsize * 4, 3, patchsize // scale, scale,
                        patchsize // scale,
                        scale).transpose(4, 3).contiguous().view(
                            opt.batchsize * 4, 3, patchsize // scale,
                            patchsize // scale, -1)
                    recon2_patch = recon2.view(
                        opt.batchsize * 4, 3, patchsize // scale, scale,
                        patchsize // scale,
                        scale).transpose(4, 3).contiguous().view(
                            opt.batchsize * 4, 3, patchsize // scale,
                            patchsize // scale, -1)
                    img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
                    recon2_patch_mean = recon2_patch.mean(dim=-1)
                    recon2_patch_cov = []
                    img_originalc_patch_cov = []

                    for j in range(3):
                        recon2_patch_cov.append(
                            (recon2_patch * recon2_patch[:, j:j + 1].repeat(
                                1, 3, 1, 1, 1)).mean(dim=-1))
                        img_originalc_patch_cov.append(
                            (img_originalc_patch *
                             img_originalc_patch[:, j:j + 1].repeat(
                                 1, 3, 1, 1, 1)).mean(dim=-1))

                    recon2_patch_cov_cat = torch.cat(recon2_patch_cov, 1)
                    img_originalc_patch_cov_cat = torch.cat(
                        img_originalc_patch_cov, 1)

                    z = Variable(
                        torch.zeros(img_originalc_patch_mean.size()).cuda())
                    if opt.joint:
                        color_loss = l2(
                            recon2_patch_mean - img_originalc_patch_mean, z)
                    else:
                        color_loss = l2(recon2_patch_mean,
                                        img_originalc_patch_mean)

                    loss2 += opt.color_coeff * color_loss

                    print("color loss %f" % color_loss.data[0])

                loss2 = loss2 * 0.3
                loss2.backward(retain_graph=True)
                print("loss2 %f" % loss2.data[0])
                optimizerG2.step()

                if i % 10 == 0:
                    writer.add_scalar('MSEloss2', loss2.item(), step)

            # from IPython import embed; embed()
            if opt.loss == "train_init":
                for param in comp.parameters():
                    if len(param.size()) == 4:
                        # print(param.size())
                        nk = param.size()[2] // 2
                        if nk > 5:
                            param.grad[:nk, :, :, :] = 0

            optimizerG.step()

            print('[%d/%d][%d/%d] %d MSEloss: %f' %
                  (epoch, opt.nepoch, i, len(dataloader), step,
                   loss.data.item()))

            if i % 200 == 0:

                test_i, test_data = next(test_loader_enum)
                if test_i > len(dataloader_test) - 5:
                    test_loader_enum = enumerate(dataloader_test)

                source = test_data[0]
                source_depth = test_data[1]
                target = test_data[2]

                mask = (torch.sum(source[:, :3, :, :], 1) >
                        0).float().unsqueeze(1)

                source[:, :3, :, :] += (
                    1 - mask.repeat(1, 3, 1, 1)) * mean.view(
                        1, 3, 1, 1).repeat(opt.batchsize, 1, 1024, 2048)
                source_depth = source_depth[:, :, :, 0].unsqueeze(1)
                source_depth = torch.cat([source_depth, mask], 1)
                img.data.copy_(source)
                maskv.data.copy_(source_depth)
                img_original.data.copy_(target)
                imgc, maskvc, img_originalc = crop(img, maskv, img_original,
                                                   zoom, patchsize)
                comp.eval()
                recon = comp(imgc, maskvc)
                comp.train()

                if opt.unfiller:
                    comp2.eval()
                    # maskvc.data.fill_(0)
                    recon2 = comp2(img_originalc, maskvc)
                    comp2.train()
                    visual = torch.cat([
                        imgc.data[:, :3, :, :], recon.data, recon2.data,
                        img_originalc.data
                    ], 3)
                else:
                    visual = torch.cat([
                        imgc.data[:, :3, :, :], recon.data, img_originalc.data
                    ], 3)

                visual = vutils.make_grid(visual, normalize=True)
                writer.add_image('image', visual, step)
                vutils.save_image(visual,
                                  '%s/compare%d_%d.png' % (opt.outf, epoch, i),
                                  nrow=1)

            if i % 10 == 0:
                writer.add_scalar('MSEloss', loss.item(), step)
                writer.add_scalar('G_loss', errG_data, step)
                writer.add_scalar('D_loss', errD_data, step)

            if i % 2000 == 0:
                torch.save(comp.state_dict(),
                           '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))

                if opt.unfiller:
                    torch.save(
                        comp2.state_dict(),
                        '%s/compG2_epoch%d_%d.pth' % (opt.outf, epoch, i))
Example #8
    def __init__(self, port, imgs, depths, target, target_poses, scale_up, semantics=None, \
                 gui=True,  use_filler=True, gpu_idx=0, windowsz=256, env = None):

        self.env = env
        self.roll, self.pitch, self.yaw = 0, 0, 0
        self.quat = [1, 0, 0, 0]
        self.x, self.y, self.z = 0, 0, 0
        self.fps = 0
        self.mousex, self.mousey = 0.5, 0.5
        self.org_pitch, self.org_yaw, self.org_roll = 0, 0, 0
        self.org_x, self.org_y, self.org_z = 0, 0, 0
        self.clickstart = (0, 0)
        self.mousedown = False
        self.overlay = False
        self.show_depth = False

        self.port = port
        self._context_phys = zmq.Context()
        self._context_mist = zmq.Context()
        self._context_dept = zmq.Context()  ## Channel for smoothed depth
        self._context_norm = zmq.Context()  ## Channel for surface normals
        self._context_semt = zmq.Context()
        self.env = env

        self._require_semantics = 'semantics' in self.env.config[
            "output"]  #configs.View.SEMANTICS in configs.ViewComponent.getComponents()
        self._require_normal = 'normal' in self.env.config[
            "output"]  #configs.View.NORMAL in configs.ViewComponent.getComponents()

        self.socket_mist = self._context_mist.socket(zmq.REQ)
        self.socket_mist.connect("tcp://localhost:{}".format(self.port - 1))
        #self.socket_dept = self._context_dept.socket(zmq.REQ)
        #self.socket_dept.connect("tcp://localhost:{}".format(5555 - 1))
        if self._require_normal:
            self.socket_norm = self._context_norm.socket(zmq.REQ)
            self.socket_norm.connect("tcp://localhost:{}".format(self.port -
                                                                 2))
        if self._require_semantics:
            self.socket_semt = self._context_semt.socket(zmq.REQ)
            self.socket_semt.connect("tcp://localhost:{}".format(self.port -
                                                                 3))

        self.target_poses = target_poses
        self.pose_locations = np.array(
            [tp[:3, -1] for tp in self.target_poses])

        self.relative_poses = [
            np.dot(np.linalg.inv(tg), self.target_poses[0])
            for tg in target_poses
        ]

        self.imgs = imgs
        self.depths = depths
        self.target = target
        self.semantics = semantics
        self.model = None
        self.old_topk = set([])
        self.k = 5
        self.use_filler = use_filler

        self.showsz = windowsz
        self.capture_count = 0

        #print(self.showsz)
        #self.show   = np.zeros((self.showsz,self.showsz * 2,3),dtype='uint8')
        #self.show_rgb   = np.zeros((self.showsz,self.showsz * 2,3),dtype='uint8')

        self.show = np.zeros((self.showsz, self.showsz, 3), dtype='uint8')
        self.show_rgb = np.zeros((self.showsz, self.showsz, 3), dtype='uint8')
        self.show_semantics = np.zeros((self.showsz, self.showsz, 3),
                                       dtype='uint8')

        self.show_prefilled = np.zeros((self.showsz, self.showsz, 3),
                                       dtype='uint8')
        self.surface_normal = np.zeros((self.showsz, self.showsz, 3),
                                       dtype='uint8')

        self.semtimg_count = 0

        if "fast_lq_render" in self.env.config and self.env.config[
                "fast_lq_render"] == True:
            comp = CompletionNet(norm=nn.BatchNorm2d,
                                 nf=24,
                                 skip_first_bn=True)
        else:
            comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
        comp = torch.nn.DataParallel(comp).cuda()
        #comp.load_state_dict(torch.load(os.path.join(assets_file_dir, "model_{}.pth".format(self.env.config["resolution"]))))

        if self.env.config["resolution"] <= 64:
            res = 64
        elif self.env.config["resolution"] <= 128:
            res = 128
        elif self.env.config["resolution"] <= 256:
            res = 256
        else:
            res = 512

        if "fast_lq_render" in self.env.config and self.env.config[
                "fast_lq_render"] == True:
            comp.load_state_dict(
                torch.load(
                    os.path.join(assets_file_dir,
                                 "model_small_{}.pth".format(res))))
        else:
            comp.load_state_dict(
                torch.load(
                    os.path.join(assets_file_dir, "model_{}.pth".format(res))))

        #comp.load_state_dict(torch.load(os.path.join(file_dir, "model.pth")))
        #comp.load_state_dict(torch.load(os.path.join(file_dir, "model_large.pth")))
        self.model = comp.module
        self.model.eval()

        if not self.env.config["use_filler"]:
            self.model = None

        self.imgs_topk = None
        self.depths_topk = None
        self.relative_poses_topk = None
        self.old_topk = None

        self.imgv = Variable(torch.zeros(1, 3, self.showsz, self.showsz),
                             volatile=True).cuda()
        self.maskv = Variable(torch.zeros(1, 2, self.showsz, self.showsz),
                              volatile=True).cuda()
        self.mean = torch.from_numpy(
            np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))
        self.mean = self.mean.view(3, 1, 1).repeat(1, self.showsz, self.showsz)

        if gui and not self.env.config["display_ui"]:
            self.renderToScreenSetup()
Example #9
    for k, v in uuids:
        #print(k,v)
        data = d[v]
        source = data[0][0]
        target = data[1]
        target_depth = data[3]
        source_depth = data[2][0]
        pose = data[-1][0].numpy()
        targets.append(target)
        poses.append(pose)
        sources.append(target)
        source_depths.append(target_depth)

    model = None
    if opt.model != '':
        comp = CompletionNet()
        comp = torch.nn.DataParallel(comp).cuda()
        comp.load_state_dict(torch.load(opt.model))
        model = comp.module
        model.eval()
    print(model)
    print('target', poses, poses[0])
    #print('no.1 pose', poses, poses[1])
    # print(source_depth)
    print(sources[0].shape, source_depths[0].shape)

    show_target(target)

    renderer = PCRenderer(5556, sources, source_depths, target, rts)
    #renderer.renderToScreen(sources, source_depths, poses, model, target, target_depth, rts)
    renderer.renderOffScreenSetup()
Example #10
class VisionSensor(BaseSensor):
    """
    Vision sensor (including rgb, rgb_filled, depth, pc (point cloud), seg, normal, optical flow, scene flow)
    """
    def __init__(self, env, modalities):
        super(VisionSensor, self).__init__(env)
        self.modalities = modalities
        self.raw_modalities = self.get_raw_modalities(modalities)
        self.image_width = self.config.get('image_width', 128)
        self.image_height = self.config.get('image_height', 128)

        self.depth_noise_rate = self.config.get('depth_noise_rate', 0.0)
        self.depth_low = self.config.get('depth_low', 0.5)
        self.depth_high = self.config.get('depth_high', 5.0)

        self.noise_model = DropoutSensorNoise(env)
        self.noise_model.set_noise_rate(self.depth_noise_rate)
        self.noise_model.set_noise_value(0.0)

        if 'rgb_filled' in modalities:
            try:
                import torch.nn as nn
                import torch
                from torchvision import transforms
                from gibson2.learn.completion import CompletionNet
            except ImportError:
                raise Exception(
                    'Trying to use rgb_filled ("the goggle"), but torch is not installed. Try "pip install torch torchvision".'
                )

            self.comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
            self.comp = torch.nn.DataParallel(self.comp).cuda()
            self.comp.load_state_dict(
                torch.load(
                    os.path.join(gibson2.assets_path, 'networks',
                                 'model.pth')))
            self.comp.eval()

    def get_raw_modalities(self, modalities):
        """
        Helper function that gathers raw modalities (e.g. depth is based on 3d)

        :return: raw modalities to query the renderer
        """
        raw_modalities = []
        if 'rgb' in modalities or 'rgb_filled' in modalities:
            raw_modalities.append('rgb')
        if 'depth' in modalities or '3d' in modalities:
            raw_modalities.append('3d')
        if 'seg' in modalities:
            raw_modalities.append('seg')
        if 'normal' in modalities:
            raw_modalities.append('normal')
        if 'optical_flow' in modalities:
            raw_modalities.append('optical_flow')
        if 'scene_flow' in modalities:
            raw_modalities.append('scene_flow')
        return raw_modalities

    def get_rgb(self, raw_vision_obs):
        """
        :return: RGB sensor reading, normalized to [0.0, 1.0]
        """
        return raw_vision_obs['rgb'][:, :, :3]

    def get_rgb_filled(self, raw_vision_obs):
        """
        :return: RGB-filled sensor reading by passing through the "Goggle" neural network
        """
        rgb = self.get_rgb(raw_vision_obs)
        with torch.no_grad():
            tensor = transforms.ToTensor()((rgb * 255).astype(np.uint8)).cuda()
            rgb_filled = self.comp(tensor[None, :, :, :])[0]
            return rgb_filled.permute(1, 2, 0).cpu().numpy()

    def get_depth(self, raw_vision_obs):
        """
        :return: depth sensor reading, normalized to [0.0, 1.0]
        """
        depth = -raw_vision_obs['3d'][:, :, 2:3]
        # 0.0 is a special value for invalid entries
        depth[depth < self.depth_low] = 0.0
        depth[depth > self.depth_high] = 0.0

        # re-scale depth to [0.0, 1.0]
        depth /= self.depth_high
        depth = self.noise_model.add_noise(depth)

        return depth
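
    # Worked example for the depth normalization above, assuming the default
    # thresholds depth_low=0.5 and depth_high=5.0 (values are illustrative):
    #   z = 0.3 m -> below depth_low  -> set to 0.0 (invalid)
    #   z = 2.5 m -> within range     -> 2.5 / 5.0 = 0.5 after rescaling
    #   z = 7.0 m -> above depth_high -> set to 0.0 (invalid)
    # The dropout noise model is applied after this rescaling.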

    def get_pc(self, raw_vision_obs):
        """
        :return: pointcloud sensor reading
        """
        return raw_vision_obs['3d'][:, :, :3]

    def get_optical_flow(self, raw_vision_obs):
        """
        :return: optical flow sensor reading
        """
        return raw_vision_obs['optical_flow'][:, :, :3]

    def get_scene_flow(self, raw_vision_obs):
        """
        :return: scene flow sensor reading
        """
        return raw_vision_obs['scene_flow'][:, :, :3]

    def get_normal(self, raw_vision_obs):
        """
        :return: surface normal reading
        """
        return raw_vision_obs['normal'][:, :, :3]

    def get_seg(self, raw_vision_obs):
        """
        :return: semantic segmentation mask, normalized to [0.0, 1.0]
        """
        seg = raw_vision_obs['seg'][:, :, 0:1]
        return seg

    def get_obs(self, env):
        """
        Get vision sensor reading

        :return: vision sensor reading
        """
        raw_vision_obs = env.simulator.renderer.render_robot_cameras(
            modes=self.raw_modalities)

        raw_vision_obs = {
            mode: value
            for mode, value in zip(self.raw_modalities, raw_vision_obs)
        }

        vision_obs = OrderedDict()
        if 'rgb' in self.modalities:
            vision_obs['rgb'] = self.get_rgb(raw_vision_obs)
        if 'rgb_filled' in self.modalities:
            vision_obs['rgb_filled'] = self.get_rgb_filled(raw_vision_obs)
        if 'depth' in self.modalities:
            vision_obs['depth'] = self.get_depth(raw_vision_obs)
        if 'pc' in self.modalities:
            vision_obs['pc'] = self.get_pc(raw_vision_obs)
        if 'optical_flow' in self.modalities:
            vision_obs['optical_flow'] = self.get_optical_flow(raw_vision_obs)
        if 'scene_flow' in self.modalities:
            vision_obs['scene_flow'] = self.get_scene_flow(raw_vision_obs)
        if 'normal' in self.modalities:
            vision_obs['normal'] = self.get_normal(raw_vision_obs)
        if 'seg' in self.modalities:
            vision_obs['seg'] = self.get_seg(raw_vision_obs)
        return vision_obs
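
A minimal usage sketch for VisionSensor.get_obs(). The environment object here is hypothetical; it only needs to provide the config entries read in __init__ and the simulator/renderer interface used by get_obs.

    # `env` is assumed to be a gibson2-style environment exposing env.config
    # (image_width, image_height, depth_low, depth_high, ...) and
    # env.simulator.renderer.render_robot_cameras(modes=...)
    sensor = VisionSensor(env, modalities=['rgb', 'depth'])
    obs = sensor.get_obs(env)
    rgb = obs['rgb']      # H x W x 3, normalized to [0.0, 1.0]
    depth = obs['depth']  # H x W x 1, normalized to [0.0, 1.0]; 0.0 marks invalid readings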