Example #1
0
class Point2DEnv(MultitaskEnv, Serializable):
    """
    A little 2D point whose life goal is to reach a target.

    The point moves inside the square
    ``[-boundary_dist, boundary_dist] x [-boundary_dist, boundary_dist]``.
    Observations are dicts holding copies of the current position and of the
    target position under several key names (see ``_get_obs``).
    """

    def __init__(
            self,
            render_dt_msec=0,
            action_l2norm_penalty=0,  # disabled for now
            render_onscreen=True,
            render_size=84,
            reward_type="dense",
            target_radius=0.5,
            boundary_dist=4,
            ball_radius=0.25,
            walls=None,
            fixed_goal=None,
            randomize_position_on_reset=True,
            **kwargs):
        """
        Store the configuration and build the action/observation spaces.

        :param render_dt_msec: value passed to the viewer's ``tick`` after
            each rendered frame.
        :param action_l2norm_penalty: stored, but not read by this class.
        :param render_onscreen: forwarded to ``PygameViewer``.
        :param render_size: width and height passed to ``PygameViewer``.
        :param reward_type: ``"dense"`` or ``"sparse"``.
        :param target_radius: radius used for success and the sparse reward.
        :param boundary_dist: half side length of the square arena.
        :param ball_radius: radius of the drawn ball.
        :param walls: iterable of wall objects (used through
            ``handle_collision``, ``endpoint1`` and ``endpoint2``); ``None``
            means no walls.
        :param fixed_goal: when not ``None``, ``sample_goals`` repeats it.
        :param randomize_position_on_reset: whether ``reset`` draws a new
            position.
        :param kwargs: ignored; a warning is printed when non-empty.
        """
        # A fresh list per instance: a ``[]`` default would be shared by
        # every environment constructed without explicit walls.
        if walls is None:
            walls = []
        if kwargs:
            print("WARNING, ignoring kwargs:", kwargs)
        # quick_init is inherited and receives the constructor's locals
        # (presumably to record them for serialization -- TODO confirm).
        self.quick_init(locals())
        self.render_dt_msec = render_dt_msec
        self.action_l2norm_penalty = action_l2norm_penalty
        self.render_onscreen = render_onscreen
        self.render_size = render_size
        self.reward_type = reward_type
        self.target_radius = target_radius
        self.boundary_dist = boundary_dist
        self.ball_radius = ball_radius
        self.walls = walls
        self.fixed_goal = fixed_goal
        self.randomize_position_on_reset = randomize_position_on_reset

        self._max_episode_steps = 50
        self.max_target_distance = self.boundary_dist - self.target_radius

        self._target_position = None
        self._position = np.zeros(2)

        u = np.ones(2)
        self.action_space = spaces.Box(-u, u, dtype=np.float32)

        o = self.boundary_dist * np.ones(2)
        self.obs_range = spaces.Box(-o, o, dtype='float32')
        self.observation_space = spaces.Dict([
            ('observation', self.obs_range),
            ('desired_goal', self.obs_range),
            ('achieved_goal', self.obs_range),
            ('state_observation', self.obs_range),
            ('state_desired_goal', self.obs_range),
            ('state_achieved_goal', self.obs_range),
        ])

        self.drawer = None

    def step(self, velocities):
        """
        Move the point by ``velocities`` (clipped to [-1, 1] per axis).

        Every wall may adjust the move through ``handle_collision``; the
        result is clipped to the arena.

        :return: ``(ob, reward, done, info)``; ``done`` is always ``False``.
        """
        velocities = np.clip(velocities, a_min=-1, a_max=1)
        new_position = self._position + velocities
        for wall in self.walls:
            new_position = wall.handle_collision(self._position, new_position)
        self._position = np.clip(
            new_position,
            a_min=-self.boundary_dist,
            a_max=self.boundary_dist,
        )
        distance_to_target = np.linalg.norm(
            self._position - self._target_position
        )
        is_success = distance_to_target < self.target_radius

        ob = self._get_obs()
        reward = self.compute_reward(velocities, ob)
        info = {
            'radius': self.target_radius,
            'target_position': self._target_position,
            'distance_to_target': distance_to_target,
            'velocity': velocities,
            'speed': np.linalg.norm(velocities),
            'is_success': is_success,
        }
        done = False
        return ob, reward, done, info

    def _sample_goal(self):
        """Draw a target uniformly from [-max_target_distance, max_target_distance]^2."""
        return np.random.uniform(size=2,
                                 low=-self.max_target_distance,
                                 high=self.max_target_distance)

    def reset(self):
        """
        Sample a new target (and, if enabled, a new position).

        :return: the observation dict from ``_get_obs``.
        """
        # NOTE(review): ``fixed_goal`` is not consulted here, only in
        # ``sample_goals`` -- confirm that this is intended.
        self._target_position = self._sample_goal()
        if self.randomize_position_on_reset:
            self._position = np.random.uniform(size=2,
                                               low=-self.boundary_dist,
                                               high=self.boundary_dist)
        return self._get_obs()

    def seed(self, s):
        """Do nothing for seed"""
        pass

    def _get_obs(self):
        """Return a dict of independent copies of the position and target."""
        return dict(
            observation=self._position.copy(),
            desired_goal=self._target_position.copy(),
            achieved_goal=self._position.copy(),
            state_observation=self._position.copy(),
            state_desired_goal=self._target_position.copy(),
            state_achieved_goal=self._position.copy(),
        )

    def compute_rewards(self, actions, obs):
        """
        Compute rewards from ``obs['observation']`` and ``obs['desired_goal']``.

        :param actions: unused.
        :param obs: dict with ``'observation'`` and ``'desired_goal'``; the
            norm is taken over the last axis.
        :return: ``-distance`` for ``"dense"``; for ``"sparse"`` a float32
            value of -1 where the distance exceeds ``target_radius``, else 0.
        :raises NotImplementedError: for any other ``reward_type`` (instead
            of silently returning ``None``).
        """
        achieved_goals = obs['observation']
        desired_goals = obs['desired_goal']
        d = np.linalg.norm(achieved_goals - desired_goals, axis=-1)
        if self.reward_type == "sparse":
            return -(d > self.target_radius).astype(np.float32)
        if self.reward_type == "dense":
            return -d
        raise NotImplementedError(
            "Unknown reward_type: {!r}".format(self.reward_type)
        )

    def get_goal(self):
        """Return copies of the current target under both goal keys."""
        return {
            'desired_goal': self._target_position.copy(),
            'state_desired_goal': self._target_position.copy(),
        }

    def sample_goals(self, batch_size):
        """
        Return ``batch_size`` goals under both goal keys.

        ``fixed_goal`` is repeated when set; otherwise goals are drawn
        uniformly between ``obs_range.low`` and ``obs_range.high``.
        """
        if self.fixed_goal is not None:
            goals = np.repeat(self.fixed_goal.copy()[None], batch_size, 0)
        else:
            goals = np.random.uniform(
                self.obs_range.low,
                self.obs_range.high,
                size=(batch_size, self.obs_range.low.size),
            )
        return {
            'desired_goal': goals,
            'state_desired_goal': goals,
        }

    def set_position(self, pos):
        """Overwrite the two position coordinates in place."""
        self._position[0] = pos[0]
        self._position[1] = pos[1]

    # ---- Functions for ImageEnv wrapper ----

    def get_image(self):
        """Render the scene and return ``(-red + blue)`` flattened."""
        self.render()
        img = self.drawer.get_image()
        # The green channel is deliberately left out of the combination.
        red, blue = img[:, :, 0], img[:, :, 2]
        return (-red + blue).flatten()

    def set_to_goal(self, goal_dict):
        """Place both the point and the target at ``goal_dict['desired_goal']``."""
        goal = goal_dict["desired_goal"]
        # Separate copies: ``set_position`` writes in place and must not
        # modify the target or the caller's array.
        self._position = np.array(goal)
        self._target_position = np.array(goal)

    def get_env_state(self):
        """Return the current observation dict as the environment state."""
        return self._get_obs()

    def set_env_state(self, state):
        """Restore position and target from a dict produced by ``get_env_state``."""
        # Copy so later in-place writes do not leak into ``state``.
        self._position = np.array(state["state_observation"])
        self._target_position = np.array(state["state_desired_goal"])

    def render(self, close=False):
        """
        Draw target, ball and walls with a ``PygameViewer``.

        :param close: when true, drop the viewer and draw nothing.
        """
        if close:
            self.drawer = None
            return

        if self.drawer is None or self.drawer.terminated:
            self.drawer = PygameViewer(
                self.render_size,
                self.render_size,
                x_bounds=(-self.boundary_dist, self.boundary_dist),
                y_bounds=(-self.boundary_dist, self.boundary_dist),
                render_onscreen=self.render_onscreen,
            )

        self.drawer.fill(Color('white'))
        self.drawer.draw_solid_circle(
            self._target_position,
            self.target_radius,
            Color('green'),
        )
        self.drawer.draw_solid_circle(
            self._position,
            self.ball_radius,
            Color('blue'),
        )

        for wall in self.walls:
            self.drawer.draw_segment(
                wall.endpoint1,
                wall.endpoint2,
                Color('black'),
            )

        self.drawer.render()
        self.drawer.tick(self.render_dt_msec)

    # ---- Static visualization/utility methods ----

    @staticmethod
    def _resolve_boundary_dist(boundary_dist=None):
        """
        Pick the arena half-width for the static helpers.

        ``boundary_dist`` is only ever assigned on instances in this class,
        so fall back to a class-level attribute when one exists and otherwise
        to the constructor default of 4.
        """
        if boundary_dist is not None:
            return boundary_dist
        return getattr(Point2DEnv, 'boundary_dist', 4)

    @staticmethod
    def true_model(state, action, boundary_dist=None):
        """
        Wall-free dynamics: clip the action, add it, clip to the arena.

        :param boundary_dist: arena half-width; see ``_resolve_boundary_dist``
            for the default.
        """
        bound = Point2DEnv._resolve_boundary_dist(boundary_dist)
        velocities = np.clip(action, a_min=-1, a_max=1)
        return np.clip(state + velocities, a_min=-bound, a_max=bound)

    @staticmethod
    def true_states(state, actions, boundary_dist=None):
        """Roll ``true_model`` over ``actions``; returns ``len(actions) + 1`` states."""
        real_states = [state]
        for action in actions:
            state = Point2DEnv.true_model(state, action, boundary_dist)
            real_states.append(state)
        return real_states

    @staticmethod
    def plot_trajectory(ax, states, actions, goal=None, boundary_dist=None):
        """
        Plot states, state deltas, actions and the arena border on ``ax``.

        The y axis is negated for every plotted quantity.

        :param ax: matplotlib axes.
        :param states: array indexed as ``states[:, 0]`` / ``states[:, 1]``
            with one more row than ``actions``.
        :param actions: array indexed the same way as ``states``.
        :param goal: optional goal, drawn as a green star.
        :param boundary_dist: arena half-width; see ``_resolve_boundary_dist``
            for the default.
        """
        assert len(states) == len(actions) + 1
        bound = Point2DEnv._resolve_boundary_dist(boundary_dist)
        x = states[:, 0]
        y = -states[:, 1]
        num_states = len(states)
        plasma_cm = plt.get_cmap('plasma')
        for i, state in enumerate(states):
            ax.plot(
                state[0],
                -state[1],
                marker='o',
                color=plasma_cm(float(i) / num_states),
                markersize=10,
            )

        actions_x = actions[:, 0]
        actions_y = -actions[:, 1]

        # Black arrows: actual state-to-state displacement.
        ax.quiver(x[:-1],
                  y[:-1],
                  x[1:] - x[:-1],
                  y[1:] - y[:-1],
                  scale_units='xy',
                  angles='xy',
                  scale=1,
                  width=0.005)
        # Red arrows: the commanded actions.
        ax.quiver(
            x[:-1],
            y[:-1],
            actions_x,
            actions_y,
            scale_units='xy',
            angles='xy',
            scale=1,
            color='r',
            width=0.0035,
        )

        # The four border segments, as (xs, ys) pairs.
        border = (
            ([-bound, -bound], [bound, -bound]),
            ([bound, -bound], [bound, bound]),
            ([bound, bound], [bound, -bound]),
            ([bound, -bound], [-bound, -bound]),
        )
        for xs, ys in border:
            ax.plot(xs, ys, color='k', linestyle='-')

        if goal is not None:
            ax.plot(goal[0], -goal[1], marker='*', color='g', markersize=15)
        ax.set_ylim(-bound - 1, bound + 1)
        ax.set_xlim(-bound - 1, bound + 1)
Example #2
0
class Point2DEnv(MultitaskEnv, Serializable):
    """
    A little 2D point whose life goal is to reach a target.

    The point moves inside the square
    ``[-boundary_dist, boundary_dist] x [-boundary_dist, boundary_dist]``.
    Observations are ordered dicts holding copies of the current position
    and of the target position under several key names (see ``_get_obs``).
    """

    def __init__(
            self,
            render_dt_msec=0,
            action_l2norm_penalty=0,  # disabled for now
            render_onscreen=True,
            render_size=84,
            reward_type="dense",
            target_radius=0.5,
            boundary_dist=4,
            ball_radius=0.25,
            walls=None,
            fixed_goal=None,
            randomize_position_on_reset=True,
            images_are_rgb=False,  # else black and white
            show_goal=True,
            action_limit=1,
            initial_position=(0, 0),
            **kwargs
    ):
        """
        Store the configuration and build the action/observation spaces.

        :param render_dt_msec: value passed to the viewer's ``tick`` after
            each rendered frame.
        :param action_l2norm_penalty: stored, but not read by this class.
        :param render_onscreen: forwarded to ``PygameViewer``.
        :param render_size: initial width and height of the viewer.
        :param reward_type: ``"dense"`` or ``"sparse"``.
        :param target_radius: radius used for success and the sparse reward.
        :param boundary_dist: half side length of the square arena.
        :param ball_radius: radius of the drawn ball.
        :param walls: iterable of wall objects; ``None`` means no walls.
        :param fixed_goal: see ``reset`` and ``sample_goals``.
        :param randomize_position_on_reset: whether ``reset`` draws a new
            position instead of using ``initial_position``.
        :param images_are_rgb: stored, but not read by this class.
        :param show_goal: whether ``_render`` draws the target.
        :param action_limit: per-axis bound applied to actions in ``step``.
        :param initial_position: position used by ``reset`` when it does not
            randomize.
        :param kwargs: ignored; a warning is logged when non-empty.
        """
        if walls is None:
            walls = []
        if kwargs:
            import logging
            LOGGER = logging.getLogger(__name__)
            # The ``%s`` placeholder is required for the kwargs to show up
            # in the emitted message.
            LOGGER.warning("WARNING, ignoring kwargs: %s", kwargs)
        # quick_init is inherited and receives the constructor's locals
        # (presumably to record them for serialization -- TODO confirm).
        self.quick_init(locals())
        self.render_dt_msec = render_dt_msec
        self.action_l2norm_penalty = action_l2norm_penalty
        self.render_onscreen = render_onscreen
        self.render_size = render_size
        self.reward_type = reward_type
        self.target_radius = target_radius
        self.boundary_dist = boundary_dist
        self.ball_radius = ball_radius
        self.walls = walls
        self.fixed_goal = fixed_goal
        self.randomize_position_on_reset = randomize_position_on_reset
        self.images_are_rgb = images_are_rgb
        self.show_goal = show_goal
        self.action_limit = action_limit

        self._max_episode_steps = 50
        self.max_target_distance = self.boundary_dist - self.target_radius

        self._target_position = None
        self._position = np.zeros(2)

        u = self.action_limit * np.ones(2)
        self.action_space = spaces.Box(-u, u, dtype=np.float32)

        o = self.boundary_dist * np.ones(2)
        self.obs_range = spaces.Box(-o, o, dtype='float32')
        self.observation_space = spaces.Dict([
            ('observation', self.obs_range),
            ('desired_goal', self.obs_range),
            ('achieved_goal', self.obs_range),
            ('state_observation', self.obs_range),
            ('state_desired_goal', self.obs_range),
            ('state_achieved_goal', self.obs_range),
        ])
        self.observation_space.shape = (2,)   # hack
        self._step = 0
        self.initial_position = initial_position

        self.drawer = None

    def step(self, velocities):
        """
        Move the point by ``velocities`` (clipped to ``action_limit``).

        Every wall may adjust the move through ``handle_collision``; the
        result is clipped to the arena.

        :return: ``(ob, reward, done, info)``; ``done`` is always ``False``.
        """
        velocities = np.clip(
            velocities, a_min=-self.action_limit, a_max=self.action_limit
        )
        new_position = self._position + velocities
        for wall in self.walls:
            new_position = wall.handle_collision(
                self._position, new_position
            )
        self._position = np.clip(
            new_position,
            a_min=-self.boundary_dist,
            a_max=self.boundary_dist,
        )
        distance_to_target = np.linalg.norm(
            self._position - self._target_position
        )
        is_success = distance_to_target < self.target_radius

        ob = self._get_obs()
        reward = self.compute_reward(velocities, ob)
        info = {
            'radius': self.target_radius,
            'target_position': self._target_position,
            'distance_to_target': distance_to_target,
            'velocity': velocities,
            'speed': np.linalg.norm(velocities),
            'is_success': is_success,
        }
        self._step += 1
        done = False
        return ob, reward, done, info

    def _sample_goal(self):
        """Draw a target uniformly from [-max_target_distance, max_target_distance]^2."""
        return np.random.uniform(
            size=2, low=-self.max_target_distance, high=self.max_target_distance
        )

    def reset(self):
        """
        Reset the step counter, the target and the position.

        Target selection by ``fixed_goal``:

        * array-like with elements: that goal is used (as a float copy);
        * truthy scalar flag (e.g. ``True``): the origin, kept for backward
          compatibility with the previous hard-coded behaviour;
        * ``None``, empty, or falsy scalar: a random target.

        :return: the observation dict from ``_get_obs``.
        """
        self._step = 0
        fixed = self.fixed_goal
        if fixed is None or np.size(fixed) == 0:
            self._target_position = self._sample_goal()
        elif np.ndim(fixed) == 0:
            # Scalar flag; a bare truthiness test on an ndarray goal would
            # raise ValueError, hence the explicit ndim dispatch.
            if fixed:
                self._target_position = np.zeros(2)
            else:
                self._target_position = self._sample_goal()
        else:
            self._target_position = np.array(fixed, dtype=float)

        if self.randomize_position_on_reset:
            self._position = np.random.uniform(
                size=2, low=-self.boundary_dist, high=self.boundary_dist
            )
        else:
            # Float dtype so that in-place writes in ``set_position`` do not
            # truncate to integers.
            self._position = np.array(self.initial_position, dtype=float)
        return self._get_obs()

    def _position_inside_wall(self, pos):
        """Return True when any wall reports that it contains ``pos``."""
        for wall in self.walls:
            if wall.contains_point(pos):
                return True
        return False

    def _get_obs(self):
        """Return an OrderedDict of independent copies of position and target."""
        return OrderedDict(
            observation=self._position.copy(),
            desired_goal=self._target_position.copy(),
            achieved_goal=self._position.copy(),
            state_observation=self._position.copy(),
            state_desired_goal=self._target_position.copy(),
            state_achieved_goal=self._position.copy(),
        )

    def compute_rewards(self, actions, obs):
        """
        Compute rewards from ``obs['state_observation']`` and
        ``obs['state_desired_goal']``.

        :param actions: unused.
        :param obs: dict with the two keys above; the norm is taken over the
            last axis.
        :return: ``-distance`` for ``"dense"``; for ``"sparse"`` a float32
            value of -1 where the distance exceeds ``target_radius``, else 0.
        :raises NotImplementedError: for any other ``reward_type`` (instead
            of silently returning ``None``).
        """
        achieved_goals = obs['state_observation']
        desired_goals = obs['state_desired_goal']
        d = np.linalg.norm(achieved_goals - desired_goals, axis=-1)
        if self.reward_type == "sparse":
            return -(d > self.target_radius).astype(np.float32)
        if self.reward_type == "dense":
            return -d
        raise NotImplementedError(
            "Unknown reward_type: {!r}".format(self.reward_type)
        )

    def get_diagnostics(self, paths, prefix=''):
        """
        Summarize ``distance_to_target`` over ``paths``.

        For the statistic, both the full per-path values and the last value
        of each path ("Final ...") are passed to
        ``create_stats_ordered_dict``.

        NOTE(review): this class used to define ``get_diagnostics`` twice;
        the later definition (tracking only ``distance_to_target``) was the
        one in effect and is the one kept here. Add ``'is_success'`` to the
        list if it should be reported as well.
        """
        statistics = OrderedDict()
        for stat_name in [
            'distance_to_target',
        ]:
            stat = get_stat_in_paths(paths, 'env_infos', stat_name)
            statistics.update(create_stats_ordered_dict(
                '%s%s' % (prefix, stat_name),
                stat,
                always_show_all_stats=True,
            ))
            statistics.update(create_stats_ordered_dict(
                'Final %s%s' % (prefix, stat_name),
                [s[-1] for s in stat],
                always_show_all_stats=True,
            ))
        return statistics

    def get_goal(self):
        """Return copies of the current target under both goal keys."""
        return {
            'desired_goal': self._target_position.copy(),
            'state_desired_goal': self._target_position.copy(),
        }

    def _sample_position(self, low, high, realistic=True):
        """
        Draw a position uniformly between ``low`` and ``high``.

        With ``realistic`` set, redraw until the position is not inside any
        wall.
        """
        pos = np.random.uniform(low, high)
        if realistic:
            while self._position_inside_wall(pos):
                pos = np.random.uniform(low, high)
        return pos

    def sample_goals(self, batch_size):
        """
        Return ``batch_size`` goals under both goal keys.

        ``fixed_goal`` is repeated when set; otherwise each goal is drawn
        with ``_sample_position`` over the observation range.
        """
        if self.fixed_goal is not None:
            goals = np.repeat(
                self.fixed_goal.copy()[None],
                batch_size,
                0
            )
        else:
            goals = np.zeros((batch_size, self.obs_range.low.size))
            for b in range(batch_size):
                goals[b, :] = self._sample_position(
                    self.obs_range.low,
                    self.obs_range.high,
                )
        return {
            'desired_goal': goals,
            'state_desired_goal': goals,
        }

    def set_position(self, pos):
        """Overwrite the two position coordinates in place."""
        self._position[0] = pos[0]
        self._position[1] = pos[1]

    # ---- Functions for ImageEnv wrapper ----

    def render(self, mode='human', **kwargs):
        """
        Draw the scene at 256x256 and return the viewer image rotated once
        with ``np.rot90``.

        ``mode`` and ``kwargs`` are accepted but not used. When the current
        ``render_size`` differs from 256 the viewer is rebuilt at that size
        and ``render_size`` is updated.
        """
        side = 256
        if side != self.render_size:
            self.drawer = PygameViewer(
                screen_width=side,
                screen_height=side,
                x_bounds=(-self.boundary_dist, self.boundary_dist),
                y_bounds=(-self.boundary_dist, self.boundary_dist),
                render_onscreen=self.render_onscreen,
            )
            self.render_size = side
        self._render()
        img = self.drawer.get_image()
        return np.rot90(img, k=1, axes=(0, 1))

    def set_to_goal(self, goal_dict):
        """Place both the point and the target at ``goal_dict['state_desired_goal']``."""
        goal = goal_dict["state_desired_goal"]
        # Separate copies: ``set_position`` writes in place and must not
        # modify the target or the caller's array.
        self._position = np.array(goal)
        self._target_position = np.array(goal)

    def get_env_state(self):
        """Return the current observation dict as the environment state."""
        return self._get_obs()

    def set_env_state(self, state):
        """Restore position and target from a dict produced by ``get_env_state``."""
        # Copy so later in-place writes do not leak into ``state``.
        self._position = np.array(state["state_observation"])
        self._target_position = np.array(state["state_desired_goal"])

    def _render(self, close=False):
        """
        Draw target (if ``show_goal``), ball and wall outlines.

        :param close: when true, drop the viewer and draw nothing.
        """
        if close:
            self.drawer = None
            return

        if self.drawer is None or self.drawer.terminated:
            self.drawer = PygameViewer(
                self.render_size,
                self.render_size,
                x_bounds=(-self.boundary_dist, self.boundary_dist),
                y_bounds=(-self.boundary_dist, self.boundary_dist),
                render_onscreen=self.render_onscreen,
            )

        self.drawer.fill(Color('white'))

        # NOTE(review): fixed red marker at (10, 10); with the default
        # boundary_dist of 4 this is outside the x/y bounds given to the
        # viewer -- confirm whether it is still wanted.
        self.drawer.draw_solid_circle(
            np.array([10, 10]),
            1,
            Color('red')
        )

        if self.show_goal:
            self.drawer.draw_solid_circle(
                self._target_position,
                self.target_radius,
                Color('green'),
            )
        self.drawer.draw_solid_circle(
            self._position,
            self.ball_radius,
            Color('blue'),
        )

        for wall in self.walls:
            # Outline: 1->2, 2->3, 3->4, 4->1.
            corners = (
                wall.endpoint1,
                wall.endpoint2,
                wall.endpoint3,
                wall.endpoint4,
            )
            for start, end in zip(corners, corners[1:] + corners[:1]):
                self.drawer.draw_segment(start, end, Color('black'))

        self.drawer.render()
        self.drawer.tick(self.render_dt_msec)

    # ---- Static visualization/utility methods ----

    @staticmethod
    def _resolve_boundary_dist(boundary_dist=None):
        """
        Pick the arena half-width for the static helpers.

        ``boundary_dist`` is only ever assigned on instances in this class,
        so fall back to a class-level attribute when one exists and otherwise
        to the constructor default of 4.
        """
        if boundary_dist is not None:
            return boundary_dist
        return getattr(Point2DEnv, 'boundary_dist', 4)

    @staticmethod
    def true_model(state, action, boundary_dist=None, action_limit=1):
        """
        Wall-free dynamics: clip the action, add it, clip to the arena.

        :param boundary_dist: arena half-width; see ``_resolve_boundary_dist``
            for the default.
        :param action_limit: per-axis action bound (default 1).
        """
        bound = Point2DEnv._resolve_boundary_dist(boundary_dist)
        velocities = np.clip(action, a_min=-action_limit, a_max=action_limit)
        return np.clip(state + velocities, a_min=-bound, a_max=bound)

    @staticmethod
    def true_states(state, actions, boundary_dist=None, action_limit=1):
        """Roll ``true_model`` over ``actions``; returns ``len(actions) + 1`` states."""
        real_states = [state]
        for action in actions:
            state = Point2DEnv.true_model(
                state, action, boundary_dist, action_limit
            )
            real_states.append(state)
        return real_states

    @staticmethod
    def plot_trajectory(ax, states, actions, goal=None, boundary_dist=None):
        """
        Plot states, state deltas, actions and the arena border on ``ax``.

        The y axis is negated for every plotted quantity.

        :param ax: matplotlib axes.
        :param states: array indexed as ``states[:, 0]`` / ``states[:, 1]``
            with one more row than ``actions``.
        :param actions: array indexed the same way as ``states``.
        :param goal: optional goal, drawn as a green star.
        :param boundary_dist: arena half-width; see ``_resolve_boundary_dist``
            for the default.
        """
        assert len(states) == len(actions) + 1
        bound = Point2DEnv._resolve_boundary_dist(boundary_dist)
        x = states[:, 0]
        y = -states[:, 1]
        num_states = len(states)
        plasma_cm = plt.get_cmap('plasma')
        for i, state in enumerate(states):
            ax.plot(
                state[0], -state[1],
                marker='o',
                color=plasma_cm(float(i) / num_states),
                markersize=10,
            )

        actions_x = actions[:, 0]
        actions_y = -actions[:, 1]

        # Black arrows: actual state-to-state displacement.
        ax.quiver(x[:-1], y[:-1], x[1:] - x[:-1], y[1:] - y[:-1],
                  scale_units='xy', angles='xy', scale=1, width=0.005)
        # Red arrows: the commanded actions.
        ax.quiver(x[:-1], y[:-1], actions_x, actions_y, scale_units='xy',
                  angles='xy', scale=1, color='r',
                  width=0.0035, )

        # The four border segments, as (xs, ys) pairs.
        border = (
            ([-bound, -bound], [bound, -bound]),
            ([bound, -bound], [bound, bound]),
            ([bound, bound], [bound, -bound]),
            ([bound, -bound], [-bound, -bound]),
        )
        for xs, ys in border:
            ax.plot(xs, ys, color='k', linestyle='-')

        if goal is not None:
            ax.plot(goal[0], -goal[1], marker='*', color='g', markersize=15)
        ax.set_ylim(-bound - 1, bound + 1)
        ax.set_xlim(-bound - 1, bound + 1)

    def initialize_camera(self, init_fctn):
        """No-op; ``init_fctn`` is ignored."""
        pass
Example #3
0
class Point2DEnv(MultitaskEnv, Serializable):
    """
    A little 2D point whose life goal is to reach a target.
    """
    def __init__(
            self,
            render_dt_msec=0,
            action_l2norm_penalty=0,  # disabled for now
            render_onscreen=False,
            render_size=200,
            render_target=True,
            reward_type="dense",
            norm_order=2,
            action_scale=1.0,
            target_radius=0.50,
            boundary_dist=4,
            ball_radius=0.50,
            walls=None,
            fixed_goal=None,
            goal_low=None,
            goal_high=None,
            ball_low=None,
            ball_high=None,
            randomize_position_on_reset=True,
            sample_realistic_goals=False,
            images_are_rgb=False,  # else black and white
            show_goal=True,
            v_func_heatmap_bounds=None,
            **kwargs):
        """
        Store the configuration and build the action/observation spaces.

        :param render_dt_msec: stored on the instance.
        :param action_l2norm_penalty: stored on the instance.
        :param render_onscreen: stored on the instance.
        :param render_size: stored on the instance.
        :param render_target: stored on the instance.
        :param reward_type: reward variant read by ``compute_rewards``.
        :param norm_order: ``ord`` passed to the norm in ``compute_rewards``.
        :param action_scale: factor applied to clipped actions in ``step``.
        :param target_radius: threshold of the sparse reward; also shrinks
            ``max_target_distance``.
        :param boundary_dist: half side length of the square arena.
        :param ball_radius: stored on the instance.
        :param walls: iterable of wall objects; ``None`` means no walls.
        :param fixed_goal: when not ``None``, converted with ``np.array`` and
            used instead of sampling goals.
        :param goal_low: lower goal bound; defaults to ``-boundary_dist`` on
            both axes.
        :param goal_high: upper goal bound; defaults to ``boundary_dist``.
        :param ball_low: lower bound for sampled positions; same default as
            ``goal_low``.
        :param ball_high: upper bound for sampled positions; same default as
            ``goal_high``.
        :param randomize_position_on_reset: whether ``reset`` samples a new
            position.
        :param sample_realistic_goals: whether rollout goals are resampled
            until they are outside every wall.
        :param images_are_rgb: stored on the instance.
        :param show_goal: stored on the instance.
        :param v_func_heatmap_bounds: stored on the instance.
        :param kwargs: ignored; a warning is logged when non-empty.
        """
        # A fresh list per instance: a ``[]`` default would be shared by
        # every environment constructed without explicit walls.
        if walls is None:
            walls = []
        if kwargs:
            import logging
            LOGGER = logging.getLogger(__name__)
            # The ``%s`` placeholder is required for the kwargs to show up
            # in the emitted message.
            LOGGER.warning("WARNING, ignoring kwargs: %s", kwargs)
        # quick_init is inherited and receives the constructor's locals
        # (presumably to record them for serialization -- TODO confirm).
        self.quick_init(locals())
        self.render_dt_msec = render_dt_msec
        self.action_l2norm_penalty = action_l2norm_penalty
        self.render_onscreen = render_onscreen
        self.render_size = render_size
        self.render_target = render_target
        self.reward_type = reward_type
        self.norm_order = norm_order
        self.action_scale = action_scale
        self.target_radius = target_radius
        self.boundary_dist = boundary_dist
        self.ball_radius = ball_radius
        self.walls = walls
        self.fixed_goal = fixed_goal
        self.randomize_position_on_reset = randomize_position_on_reset
        self.sample_realistic_goals = sample_realistic_goals
        if self.fixed_goal is not None:
            self.fixed_goal = np.array(self.fixed_goal)
        self.images_are_rgb = images_are_rgb
        self.show_goal = show_goal

        self.max_target_distance = self.boundary_dist - self.target_radius

        self._target_position = None
        self._position = np.zeros(2)

        u = np.ones(2)
        self.action_space = spaces.Box(-u, u, dtype=np.float32)

        o = self.boundary_dist * np.ones(2)
        self.obs_range = spaces.Box(-o, o, dtype='float32')

        # Goal and ball ranges default to the full arena.
        if goal_low is None:
            goal_low = -o
        if goal_high is None:
            goal_high = o
        self.goal_range = spaces.Box(np.array(goal_low),
                                     np.array(goal_high),
                                     dtype='float32')

        if ball_low is None:
            ball_low = -o
        if ball_high is None:
            ball_high = o
        self.ball_range = spaces.Box(np.array(ball_low),
                                     np.array(ball_high),
                                     dtype='float32')

        self.observation_space = spaces.Dict([
            ('observation', self.obs_range),
            ('desired_goal', self.goal_range),
            ('achieved_goal', self.obs_range),
            ('state_observation', self.obs_range),
            ('state_desired_goal', self.goal_range),
            ('state_achieved_goal', self.obs_range),
        ])

        self.drawer = None
        self.subgoals = None

        self.v_func_heatmap_bounds = v_func_heatmap_bounds

    def step(self, velocities):
        """
        Advance the point by one action.

        The action is clipped to [-1, 1] per axis and multiplied by
        ``action_scale``; each wall may adjust the move through
        ``handle_collision``; the result is clipped to the arena.

        :return: ``(ob, reward, done, info)``; ``done`` is always ``False``.
        """
        assert self.action_scale <= 1.0
        scaled = np.clip(velocities, a_min=-1, a_max=1) * self.action_scale

        start = self._position
        candidate = start + scaled
        for wall in self.walls:
            candidate = wall.handle_collision(start, candidate)
        bound = self.boundary_dist
        self._position = np.clip(candidate, a_min=-bound, a_max=bound)

        gap = np.linalg.norm(self._position - self._target_position)
        below_wall = self._position[1] >= 2.0

        ob = self._get_obs()
        reward = self.compute_reward(scaled, ob)
        # Success is reported at three fixed distance thresholds.
        info = {
            'radius': self.target_radius,
            'target_position': self._target_position,
            'distance_to_target': gap,
            'velocity': scaled,
            'speed': np.linalg.norm(scaled),
            'is_success': gap < 1.0,
            'is_success_2': gap < 1.5,
            'is_success_3': gap < 1.75,
            'below_wall': below_wall,
        }
        return ob, reward, False, info

    def initialize_camera(self, camera):
        pass

    def reset(self):
        self._target_position = self.sample_goal_for_rollout(
        )['state_desired_goal']
        if self.randomize_position_on_reset:
            self._position = self._sample_position(self.ball_range.low,
                                                   self.ball_range.high,
                                                   realistic=True)

        self.subgoals = None
        return self._get_obs()

    def _position_outside_arena(self, pos):
        """Return True when ``obs_range`` does not contain ``pos``."""
        inside = self.obs_range.contains(pos)
        return not inside

    def _position_inside_wall(self, pos):
        for wall in self.walls:
            if wall.contains_point(pos):
                return True
        return False

    def realistic_state_np(self, state):
        return not self._position_inside_wall(state)

    def realistic_goals(self, g, use_double=False):
        """
        Score a batch of goals by whether they avoid every wall.

        For each wall, the columns of ``wall.contains_points_pytorch(g)`` are
        multiplied together, the per-wall results are OR-ed, and the combined
        value is inverted with ``^ 1`` and returned through ``.float()``.
        Presumably this yields 1.0 for goals outside all walls and 0.0
        otherwise -- TODO confirm against ``contains_points_pytorch``.

        :param g: batch of goals handed to each wall's
            ``contains_points_pytorch`` (presumably a torch tensor with one
            goal per row -- TODO confirm).
        :param use_double: not used by this method.
        :return: the inverted collision flags converted with ``.float()``.

        NOTE(review): with an empty ``walls`` list ``collision`` stays
        ``None`` and ``collision ^ 1`` raises ``TypeError``.
        """
        collision = None
        for wall in self.walls:
            wall_collision = wall.contains_points_pytorch(g)
            # Reduce across columns with a product. The loop starts at 0, so
            # column 0 is multiplied by itself once (harmless if the values
            # are 0/1 flags -- assumption, verify).
            wall_collision_result = wall_collision[:, 0]
            for i in range(wall_collision.shape[1]):
                wall_collision_result = wall_collision_result * wall_collision[:,
                                                                               i]

            # Accumulate across walls: a goal collides if any wall flags it.
            if collision is None:
                collision = wall_collision_result
            else:
                collision = collision | wall_collision_result

        # Flip the flags: collision -> not realistic.
        result = collision ^ 1

        # g_np = ptu.get_numpy(g)
        # for i in range(g_np.shape[0]):
        #     print(not self._position_inside_wall(g_np[i]), ptu.get_numpy(result[i]))

        return result.float()

    def _sample_position(self, low, high, realistic=True):
        pos = np.random.uniform(low, high)
        if realistic:
            while self._position_inside_wall(pos) is True:
                pos = np.random.uniform(low, high)
        return pos

    def _get_obs(self):
        return dict(
            observation=self._position.copy(),
            desired_goal=self._target_position.copy(),
            achieved_goal=self._position.copy(),
            state_observation=self._position.copy(),
            state_desired_goal=self._target_position.copy(),
            state_achieved_goal=self._position.copy(),
        )

    def compute_rewards(self, actions, obs, prev_obs=None):
        achieved_goals = obs['state_achieved_goal']
        desired_goals = obs['state_desired_goal']
        d = np.linalg.norm(achieved_goals - desired_goals,
                           ord=self.norm_order,
                           axis=-1)
        if self.reward_type == "sparse":
            return -(d > self.target_radius).astype(np.float32)
        elif self.reward_type == "dense":
            return -d
        elif self.reward_type == 'vectorized_dense':
            return -np.abs(achieved_goals - desired_goals)

    def get_diagnostics(self, paths, prefix=''):
        """Summarise selected ``env_infos`` statistics over ``paths``.

        For each tracked statistic two groups of entries are added to the
        result: one computed over all values in every path, and one (labelled
        ``Final``) computed over only the last value of each path.

        Args:
            paths: collected rollouts, forwarded to ``get_stat_in_paths``.
            prefix: string placed in front of every statistic name.

        Returns:
            An ``OrderedDict`` with the entries produced by
            ``create_stats_ordered_dict``.
        """
        tracked_stats = (
            'distance_to_target',
            'below_wall',
            'is_success',
            'is_success_2',
            'is_success_3',
            'velocity',
            'speed',
        )
        diagnostics = OrderedDict()
        for name in tracked_stats:
            per_path = get_stat_in_paths(paths, 'env_infos', name)
            overall = create_stats_ordered_dict(
                '%s%s' % (prefix, name),
                per_path,
                always_show_all_stats=True,
            )
            diagnostics.update(overall)
            # Only the last recorded value of each path.
            final_values = [values[-1] for values in per_path]
            final = create_stats_ordered_dict(
                'Final %s%s' % (prefix, name),
                final_values,
                always_show_all_stats=True,
            )
            diagnostics.update(final)
        return diagnostics

    def get_goal(self):
        return {
            'desired_goal': self._target_position.copy(),
            'state_desired_goal': self._target_position.copy(),
        }

    def set_goal(self, goal):
        self._target_position = goal['state_desired_goal']

    def sample_goal_for_rollout(self):
        if not self.fixed_goal is None:
            goal = np.copy(self.fixed_goal)
        else:
            goal = self._sample_position(self.goal_range.low,
                                         self.goal_range.high,
                                         realistic=self.sample_realistic_goals)
        return {
            'desired_goal': goal,
            'state_desired_goal': goal,
        }

    def sample_goals(self, batch_size):
        # goals = np.random.uniform(
        #     self.obs_range.low,
        #     self.obs_range.high,
        #     size=(batch_size, self.obs_range.low.size),
        # )
        goals = np.array([
            self._sample_position(self.obs_range.low,
                                  self.obs_range.high,
                                  realistic=self.sample_realistic_goals)
            for _ in range(batch_size)
        ])
        return {
            'desired_goal': goals,
            'state_desired_goal': goals,
        }

    def set_position(self, pos):
        self._position[0] = pos[0]
        self._position[1] = pos[1]

    """Functions for ImageEnv wrapper"""

    def get_image(self, width=None, height=None):
        """Render the scene and return it as an image array.

        When ``width`` is given it must equal ``height`` (otherwise
        ``NotImplementedError`` is raised); if it differs from
        ``self.render_size`` a new ``PygameViewer`` of that size is created
        and ``self.render_size`` is updated.

        Returns:
            If ``self.images_are_rgb`` is set, the viewer image with its first
            two axes swapped. Otherwise a single-channel image computed as
            ``-r + b`` from channels 0 and 2 of the viewer image.
        """
        if width is not None:
            if width != height:
                raise NotImplementedError()
            if width != self.render_size:
                lower = -self.boundary_dist - self.ball_radius
                upper = self.boundary_dist + self.ball_radius
                self.drawer = PygameViewer(
                    screen_width=width,
                    screen_height=height,
                    x_bounds=(lower, upper),
                    y_bounds=(lower, upper),
                    render_onscreen=self.render_onscreen,
                )
                self.render_size = width
        self.render()
        frame = self.drawer.get_image()
        if self.images_are_rgb:
            return frame.transpose((1, 0, 2))
        red = frame[:, :, 0]
        blue = frame[:, :, 2]
        return -red + blue

    def update_subgoals(self,
                        subgoals,
                        subgoals_reproj=None,
                        subgoal_v_vals=None):
        self.subgoals = subgoals

    def set_to_goal(self, goal_dict):
        goal = goal_dict["state_desired_goal"]
        self._position = goal
        self._target_position = goal

    def get_env_state(self):
        return self._get_obs()

    def set_env_state(self, state):
        position = state["state_observation"]
        goal = state["state_desired_goal"]
        self._position = position
        self._target_position = goal

    def get_states_sweep(self, nx, ny):
        x = np.linspace(-4, 4, nx)
        y = np.linspace(-4, 4, ny)
        xv, yv = np.meshgrid(x, y)
        states = np.stack((xv, yv), axis=2).reshape((-1, 2))
        return states

    def render(self, close=False):
        """Draw the current scene with the Pygame viewer.

        Args:
            close: when true, drop the viewer and return without drawing.

        A viewer is created on demand if none exists or the existing one has
        terminated. Drawing order: white background, the green target (only
        if ``self.render_target``), the blue ball, red subgoals with their
        connecting segments (if ``self.subgoals`` is not None), and finally
        the black outline of every wall.
        """
        if close:
            self.drawer = None
            return

        if self.drawer is None or self.drawer.terminated:
            lower = -self.boundary_dist - self.ball_radius
            upper = self.boundary_dist + self.ball_radius
            self.drawer = PygameViewer(
                self.render_size,
                self.render_size,
                x_bounds=(lower, upper),
                y_bounds=(lower, upper),
                render_onscreen=self.render_onscreen,
            )

        canvas = self.drawer
        canvas.fill(Color('white'))
        if self.render_target:
            canvas.draw_solid_circle(
                self._target_position,
                self.target_radius,
                Color('green'),
            )
        canvas.draw_solid_circle(
            self._position,
            self.ball_radius,
            Color('blue'),
        )

        subgoals = self.subgoals
        if subgoals is not None:
            for subgoal in subgoals:
                canvas.draw_solid_circle(
                    subgoal,
                    self.ball_radius + 0.1,
                    Color('red'),
                )
            # Segments chain: current position -> first subgoal -> ... -> last.
            segment_starts = np.concatenate(
                ([self._position], subgoals[:-1]), axis=0)
            for start, end in zip(segment_starts, subgoals):
                canvas.draw_segment(start, end, Color(100, 0, 0, 10))

        for wall in self.walls:
            corners = (
                wall.endpoint1,
                wall.endpoint2,
                wall.endpoint3,
                wall.endpoint4,
            )
            # Join consecutive corners, wrapping from the last to the first.
            for idx, start in enumerate(corners):
                end = corners[(idx + 1) % len(corners)]
                canvas.draw_segment(start, end, Color('black'))

        canvas.render()
        canvas.tick(self.render_dt_msec)

    """Static visualization/utility methods"""

    @staticmethod
    def true_model(state, action):
        velocities = np.clip(action, a_min=-1, a_max=1)
        position = state
        new_position = position + velocities
        return np.clip(
            new_position,
            a_min=-Point2DEnv.boundary_dist,
            a_max=Point2DEnv.boundary_dist,
        )

    @staticmethod
    def true_states(state, actions):
        real_states = [state]
        for action in actions:
            next_state = Point2DEnv.true_model(state, action)
            real_states.append(next_state)
            state = next_state
        return real_states

    @staticmethod
    def plot_trajectory(ax, states, actions, goal=None):
        """Plot a trajectory, its actions and the arena boundary on ``ax``.

        Args:
            ax: axes object providing ``plot``, ``quiver``, ``set_xlim`` and
                ``set_ylim``.
            states: array of visited states, one more row than ``actions``.
            actions: array of actions taken between consecutive states.
            goal: optional goal position, drawn as a green star.

        The y coordinate is negated everywhere before plotting. States are
        coloured along the ``plasma`` colormap by their index; default-colour
        arrows show the actual state-to-state displacement and red arrows show
        the commanded actions.
        """
        assert len(states) == len(actions) + 1
        xs = states[:, 0]
        ys = -states[:, 1]
        total = len(states)
        colormap = plt.get_cmap('plasma')
        for idx, state in enumerate(states):
            ax.plot(
                state[0],
                -state[1],
                marker='o',
                color=colormap(float(idx) / total),
                markersize=10,
            )

        action_dx = actions[:, 0]
        action_dy = -actions[:, 1]

        # Displacement actually realised between consecutive states.
        ax.quiver(
            xs[:-1],
            ys[:-1],
            xs[1:] - xs[:-1],
            ys[1:] - ys[:-1],
            scale_units='xy',
            angles='xy',
            scale=1,
            width=0.005,
        )
        # Commanded action at each state.
        ax.quiver(
            xs[:-1],
            ys[:-1],
            action_dx,
            action_dy,
            scale_units='xy',
            angles='xy',
            scale=1,
            color='r',
            width=0.0035,
        )

        bound = Point2DEnv.boundary_dist
        # Four edges of the square arena: left, top, right, bottom.
        boundary_edges = (
            ([-bound, -bound], [bound, -bound]),
            ([bound, -bound], [bound, bound]),
            ([bound, bound], [bound, -bound]),
            ([bound, -bound], [-bound, -bound]),
        )
        for edge_x, edge_y in boundary_edges:
            ax.plot(edge_x, edge_y, color='k', linestyle='-')

        if goal is not None:
            ax.plot(goal[0], -goal[1], marker='*', color='g', markersize=15)
        ax.set_ylim(-bound - 1, bound + 1)
        ax.set_xlim(-bound - 1, bound + 1)

    def initialize_camera(self, init_fctn):
        pass