예제 #1
0
class GripperTester:
    """
    A class that is used to test a gripper.

    Builds a small MuJoCo world (a table, a red box to grasp, and green/blue visual
    reference boxes marking the x and y axes), mounts the gripper on a slide joint
    ("gripper_z_joint") driven by a position actuator ("gripper_z"), and repeatedly runs a
    lower / grip / raise / release sequence.

    Args:
        gripper (GripperModel): A gripper instance to be tested
        pos (str): (x y z) position to place the gripper in string form, e.g. '0 0 0.3'
        quat (str): rotation to apply to gripper in string form, e.g. '0 0 1 0' to flip z axis
        gripper_low_pos (float): target of the "gripper_z" position actuator while the gripper
            is lowered. The slide axis is '0 0 -1' in the mounting body's frame, so with the
            z-flipping quat '0 0 1 0' larger values move the gripper higher.
        gripper_high_pos (float): target of the "gripper_z" position actuator while the gripper
            is raised; must not be smaller than gripper_low_pos
        box_size (None or 3-sequence of float): size passed to the BoxObject to grasp (its z
            component is also used as the offset above the table top); None defaults to
            [0.02, 0.02, 0.02]
        box_density (int): the density of the box to grasp (kg / m^3)
        step_time (int): number of simulation steps spent in each phase of the grasp sequence
        render (bool): if True, show rendering

    Raises:
        ValueError: if gripper_low_pos is larger than gripper_high_pos
    """

    def __init__(
        self,
        gripper,
        pos,
        quat,
        gripper_low_pos,
        gripper_high_pos,
        box_size=None,
        box_density=10000,
        step_time=400,
        render=True,
    ):
        # Validate cheap arguments before building any MuJoCo objects (fail fast).
        if gripper_low_pos > gripper_high_pos:
            raise ValueError(
                "gripper_low_pos {} is larger " "than gripper_high_pos {}".format(gripper_low_pos, gripper_high_pos)
            )

        # define viewer
        self.viewer = None

        world = MujocoWorldBase()
        # Add a table
        arena = TableArena(table_full_size=(0.4, 0.4, 0.1), table_offset=(0, 0, 0.1), has_legs=False)
        world.merge(arena)

        # Add a gripper
        self.gripper = gripper
        # Create another body with a slider joint to which we'll add this gripper
        gripper_body = ET.Element("body")
        gripper_body.set("pos", pos)
        gripper_body.set("quat", quat)  # orientation supplied by the caller (e.g. '0 0 1 0' flips z)
        gripper_body.append(new_joint(name="gripper_z_joint", type="slide", axis="0 0 -1", damping="50"))
        # Add all gripper bodies to this higher level body
        for body in gripper.worldbody:
            gripper_body.append(body)
        # Merge all of the gripper tags except its bodies
        world.merge(gripper, merge_body=None)
        # Manually add the higher level body we created
        world.worldbody.append(gripper_body)
        # Create a new actuator to control our slider joint
        world.actuator.append(new_actuator(joint="gripper_z_joint", act_type="position", name="gripper_z", kp="500"))

        # Add an object for grasping
        # density is in units kg / m3
        TABLE_TOP = [0, 0, 0.09]
        if box_size is None:
            box_size = [0.02, 0.02, 0.02]
        box_size = np.array(box_size)
        self.cube = BoxObject(
            name="object", size=box_size, rgba=[1, 0, 0, 1], friction=[1, 0.005, 0.0001], density=box_density
        )
        # Place the box on the table: table top plus the box's z size.
        object_pos = np.array(TABLE_TOP + box_size * [0, 0, 1])
        mujoco_object = self.cube.get_obj()
        # Set the position of this object
        mujoco_object.set("pos", array_to_string(object_pos))
        # Add our object to the world body
        world.worldbody.append(mujoco_object)

        # add visual-only reference objects marking the x (green) and y (blue) axes
        x_ref = BoxObject(
            name="x_ref", size=[0.01, 0.01, 0.01], rgba=[0, 1, 0, 1], obj_type="visual", joints=None
        ).get_obj()
        x_ref.set("pos", "0.2 0 0.105")
        world.worldbody.append(x_ref)
        y_ref = BoxObject(
            name="y_ref", size=[0.01, 0.01, 0.01], rgba=[0, 0, 1, 1], obj_type="visual", joints=None
        ).get_obj()
        y_ref.set("pos", "0 0.2 0.105")
        world.worldbody.append(y_ref)

        self.world = world
        self.render = render
        self.simulation_ready = False
        self.step_time = step_time
        self.cur_step = 0
        self.gripper_low_pos = gripper_low_pos
        self.gripper_high_pos = gripper_high_pos

    def start_simulation(self):
        """
        Starts simulation of the test world.

        Compiles the world into an MjSim, optionally opens a viewer, caches actuator/joint
        indices, records the object's resting position and resets to the initial state.
        """
        model = self.world.get_model(mode="mujoco_py")

        self.sim = MjSim(model)
        if self.render:
            self.viewer = MjViewer(self.sim)
        self.sim_state = self.sim.get_state()

        # For gravity correction: dof addresses whose bias force is cancelled every step
        gravity_corrected = ["gripper_z_joint"]
        self._gravity_corrected_qvels = [self.sim.model.get_joint_qvel_addr(x) for x in gravity_corrected]

        self.gripper_z_id = self.sim.model.actuator_name2id("gripper_z")
        self.gripper_z_is_low = False

        self.gripper_actuator_ids = [self.sim.model.actuator_name2id(x) for x in self.gripper.actuators]

        self.gripper_is_closed = True

        self.object_id = self.sim.model.body_name2id(self.cube.root_body)
        # Refresh derived MjData quantities (such as body_xpos) from the initial state before
        # sampling them; otherwise the reference height may not reflect the initial qpos.
        # forward() does not alter qpos/qvel, so the saved sim_state is unaffected.
        self.sim.forward()
        object_default_pos = self.sim.data.body_xpos[self.object_id]
        self.object_default_pos = np.array(object_default_pos, copy=True)

        self.reset()
        self.simulation_ready = True

    def reset(self):
        """
        Resets the simulation to the initial state and zeroes the step counter.
        """
        self.sim.set_state(self.sim_state)
        self.cur_step = 0

    def close(self):
        """
        Close the viewer if it exists. Safe to call more than once.
        """
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

    def step(self):
        """
        Forward the simulation by one timestep.

        Sets the slide-joint target from ``gripper_z_is_low``, the finger actuators from
        ``gripper_is_closed``, applies gravity compensation, then steps (and renders).

        Raises:
            RuntimeError: if start_simulation is not yet called.
        """
        if not self.simulation_ready:
            raise RuntimeError("Call start_simulation before calling step")
        z_target = self.gripper_low_pos if self.gripper_z_is_low else self.gripper_high_pos
        self.sim.data.ctrl[self.gripper_z_id] = z_target
        self._apply_gripper_action(1 if self.gripper_is_closed else -1)
        self._apply_gravity_compensation()
        self.sim.step()
        if self.render:
            self.viewer.render()
        self.cur_step += 1

    def _apply_gripper_action(self, action):
        """
        Applies binary gripper action

        Args:
            action (int): Action to apply. Should be -1 (open) or 1 (closed)
        """
        gripper_action_actual = self.gripper.format_action(np.array([action]))
        # rescale normalized gripper action in [-1, 1] to each actuator's control range
        ctrl_range = self.sim.model.actuator_ctrlrange[self.gripper_actuator_ids]
        bias = 0.5 * (ctrl_range[:, 1] + ctrl_range[:, 0])
        weight = 0.5 * (ctrl_range[:, 1] - ctrl_range[:, 0])
        applied_gripper_action = bias + weight * gripper_action_actual
        self.sim.data.ctrl[self.gripper_actuator_ids] = applied_gripper_action

    def _apply_gravity_compensation(self):
        """
        Applies gravity compensation by copying qfrc_bias into qfrc_applied for the
        gravity-corrected dofs (the gripper slide joint).
        """
        self.sim.data.qfrc_applied[self._gravity_corrected_qvels] = self.sim.data.qfrc_bias[
            self._gravity_corrected_qvels
        ]

    def loop(self, total_iters=1, test_y=False, y_baseline=0.01):
        """
        Performs lower, grip, raise and release actions of a gripper,
        each lasting ``step_time`` simulation steps.

        Args:
            total_iters (int): Iterations to perform before exiting
            test_y (bool): test if object is lifted
            y_baseline (float): threshold for determining that object is lifted

        Raises:
            ValueError: if test_y is set and the object height does not exceed y_baseline
                at the end of an iteration
        """
        # (gripper_z_is_low, gripper_is_closed) for each phase
        seq = [(False, False), (True, False), (True, True), (False, True)]
        for _ in range(total_iters):
            for cur_plan in seq:
                self.gripper_z_is_low, self.gripper_is_closed = cur_plan
                for _ in range(self.step_time):
                    self.step()
            if test_y and not self.object_height > y_baseline:
                raise ValueError(
                    "object is lifted by {}, ".format(self.object_height)
                    + "not reaching the requirement {}".format(y_baseline)
                )

    @property
    def object_height(self):
        """
        Queries the height (z) of the object compared to its resting position

        Returns:
            float: Object height relative to the default object position recorded in
            start_simulation
        """
        return self.sim.data.body_xpos[self.object_id][2] - self.object_default_pos[2]
예제 #2
0
class MujocoEnv(gym.Env):
    """Base MuJoCo environment exposing the gym.Env interface.

    Owns the MjSim instance, its on-screen viewer / off-screen render context and the
    control-step timing. Subclasses supply the model (``_load_model``), element
    references (``_get_reference``), observations, rewards and the action spec.
    """
    def __init__(self,
                 has_renderer=False,
                 has_offscreen_renderer=True,
                 render_collision_mesh=False,
                 render_visual_mesh=True,
                 control_freq=10,
                 horizon=1000,
                 ignore_done=False,
                 use_camera_obs=False,
                 camera_name="frontview",
                 camera_height=256,
                 camera_width=256,
                 camera_depth=False,
                 **kwargs):
        """
        Args:

            has_renderer (bool): If true, render the simulation state in
                a viewer instead of headless mode.

            has_offscreen_renderer (bool): True if using off-screen rendering.

            render_collision_mesh (bool): True if rendering collision meshes
                in camera. False otherwise.

            render_visual_mesh (bool): True if rendering visual meshes
                in camera. False otherwise.

            control_freq (float): how many control signals to receive
                in every simulated second. This sets the amount of simulation time
                that passes between every action input.

            horizon (int): Every episode lasts for exactly @horizon timesteps.

            ignore_done (bool): True if never terminating the environment (ignore @horizon).

            use_camera_obs (bool): if True, every observation includes a
                rendered image.

            camera_name (str): name of camera to be rendered. Must be
                set if @use_camera_obs is True.

            camera_height (int): height of camera frame.

            camera_width (int): width of camera frame.

            camera_depth (bool): True if rendering RGB-D, and RGB otherwise.

            **kwargs: extra options, stored in ``self._kwargs``. ``screen_width`` and
                ``screen_height`` set the image size ``render`` uses when called with
                ``width``/``height`` of None; they default to @camera_width /
                @camera_height when not given.

        Raises:
            ValueError: if camera observations are requested without an off-screen
                renderer or without a camera name.
        """

        self.seed()
        self._kwargs = kwargs
        # These keys used to be mandatory (KeyError when absent); fall back to the
        # camera frame size so the environment can be built without them.
        self._screen_width = kwargs.get("screen_width", camera_width)
        self._screen_height = kwargs.get("screen_height", camera_height)

        self.has_renderer = has_renderer
        self.has_offscreen_renderer = has_offscreen_renderer
        self.render_collision_mesh = render_collision_mesh
        self.render_visual_mesh = render_visual_mesh
        self.control_freq = control_freq
        self.horizon = horizon
        self.ignore_done = ignore_done
        self.viewer = None
        self.model = None

        # settings for camera observations
        self.use_camera_obs = use_camera_obs
        if self.use_camera_obs and not self.has_offscreen_renderer:
            raise ValueError(
                "Camera observations require an offscreen renderer.")
        self.camera_name = camera_name
        if self.use_camera_obs and self.camera_name is None:
            raise ValueError("Must specify camera name when using camera obs")
        self.camera_height = camera_height
        self.camera_width = camera_width
        self.camera_depth = camera_depth

        self.pid = None
        self._reset_internal()

    def initialize_time(self, control_freq):
        """
        Initializes the time constants used for simulation.

        Raises:
            XMLError: if the loaded model's timestep is not positive.
            SimulationError: if @control_freq is not positive.
        """
        self.cur_time = 0
        self.model_timestep = self.sim.model.opt.timestep
        if self.model_timestep <= 0:
            raise XMLError("xml model defined non-positive time step")
        self.control_freq = control_freq
        if control_freq <= 0:
            raise SimulationError(
                "control frequency {} is invalid".format(control_freq))
        self.control_timestep = 1. / control_freq

    def _setup_pid(self):
        """Function to setup the pid controller. Should set the self.pid attribute."""
        raise NotImplementedError

    def _load_model(self):
        """Loads an xml model, puts it in self.model"""
        pass

    def _get_reference(self):
        """
        Sets up references to important components. A reference is typically an
        index or a list of indices that point to the corresponding elements
        in a flatten array, which is how MuJoCo stores physical simulation data.
        """
        pass

    def reset(self):
        """Resets simulation and returns the first observation."""
        # if there is an active viewer window, destroy it
        self._destroy_viewer()
        self._reset_internal()
        self.sim.forward()
        return self._get_observation()

    def _reset_internal(self):
        """Resets simulation internal configurations."""
        # instantiate simulation from MJCF model
        self._load_model()
        self.mjpy_model = self.model.get_model(mode="mujoco_py")
        self.sim = MjSim(self.mjpy_model)
        self.initialize_time(self.control_freq)

        # create visualization screen or renderer
        self._setup_sim_renderers()

        # additional housekeeping
        self._reset_episode_counters()

    def _setup_sim_renderers(self):
        """Creates the on-screen viewer, or configures the off-screen render context, for self.sim."""
        collision_flag = 1 if self.render_collision_mesh else 0
        visual_flag = 1 if self.render_visual_mesh else 0
        if self.has_renderer and self.viewer is None:
            self.viewer = MjViewer(self.sim)
            # geom group 0 toggles collision meshes, group 1 visual meshes
            self.viewer.viewer.vopt.geomgroup[0] = collision_flag
            self.viewer.viewer.vopt.geomgroup[1] = visual_flag

            # hiding the overlay speeds up rendering significantly
            self.viewer.viewer._hide_overlay = True

            self.viewer.viewer._render_every_frame = True

        elif self.has_offscreen_renderer:
            if self.sim._render_context_offscreen is None:
                render_context = MjRenderContextOffscreen(self.sim)
                self.sim.add_render_context(render_context)
            self.sim._render_context_offscreen.vopt.geomgroup[0] = collision_flag
            self.sim._render_context_offscreen.vopt.geomgroup[1] = visual_flag

    def _reset_episode_counters(self):
        """Snapshots the initial sim state, refreshes references and zeroes episode counters."""
        self.sim_state_initial = self.sim.get_state()
        self._get_reference()
        self.cur_time = 0
        self.timestep = 0
        self.done = False

    def _get_observation(self):
        """Returns an OrderedDict containing observations [(name_string, np.array), ...]."""
        return OrderedDict()

    def step(self, action):
        """
        Takes a step in simulation with control command @action.

        Advances the simulation by one control period (1 / control_freq seconds of
        simulated time) in model-timestep increments.

        Returns:
            tuple: (observation, reward, done, info)

        Raises:
            ValueError: if the episode has already terminated.
        """
        if self.done:
            raise ValueError("executing action in terminated episode")

        self.timestep += 1
        self._pre_action(action)
        end_time = self.cur_time + self.control_timestep
        # Stop once the remaining time is a negligible fraction of a substep. A plain
        # `cur_time < end_time` test lets the rounding error accumulated by repeatedly
        # adding model_timestep trigger a spurious extra substep, making the number of
        # substeps per action vary between calls.
        tolerance = 1e-6 * self.model_timestep
        while end_time - self.cur_time > tolerance:
            if self.pid is not None:
                self._set_pid_control()
            self.sim.step()
            self.cur_time += self.model_timestep
        reward, done, info = self._post_action(action)
        return self._get_observation(), reward, done, info

    def _set_pid_control(self):
        """Overwrites all controls with the pid output for the current joint velocities."""
        current_qvel = self.sim.data.qvel
        self.sim.data.ctrl[:] = self.pid(current_qvel, self.model_timestep)

    def _pre_action(self, action):
        """Do any preprocessing before taking an action (default: write it into ctrl)."""
        self.sim.data.ctrl[:] = action

    def _post_action(self, action):
        """Do any housekeeping after taking an action; returns (reward, done, info)."""
        reward = self.reward(action)

        # done if number of elapsed timesteps is greater than horizon
        self.done = (self.timestep >= self.horizon) and not self.ignore_done
        return reward, self.done, {}

    def reward(self, action):
        """Reward should be a function of state and action."""
        return 0

    def render(self, mode="rgb_array", height=100, width=100):
        """
        Renders the current simulation state.

        Args:
            mode (str): "rgb_array" renders off-screen from the "birdview" camera and
                returns the image; "human" draws into the on-screen viewer.
            height (int or None): image height for "rgb_array"; None uses the screen
                height configured at construction.
            width (int or None): image width for "rgb_array"; None uses the screen
                width configured at construction.

        Returns:
            The rendered image for "rgb_array", otherwise None.

        Raises:
            RuntimeError: if mode is "human" but no on-screen viewer exists
                (e.g. the environment was built with has_renderer=False).
        """
        if height is None:
            height = self._screen_height
        if width is None:
            width = self._screen_width

        if mode == "rgb_array":
            camera_obs = self.sim.render(
                camera_name="birdview",
                width=width,
                height=height,
                depth=False,
            )
            assert np.sum(camera_obs) > 0, "rendering image is blank"
            return camera_obs
        elif mode == "human":
            # self.viewer is an MjViewer instance (not a factory), so render on it directly.
            if self.viewer is None:
                raise RuntimeError(
                    "human rendering requires an on-screen viewer (has_renderer=True)")
            self.viewer.render()
            return None

    def seed(self, seed=None):
        """Seeds self.np_random; returns a one-element list with the seed used."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def observation_spec(self):
        """
        Returns the observation specification.

        The specification is an OrderedDict mapping each observation name returned by
        ``_get_observation`` to the shape of its array.
        """
        observation = self._get_observation()

        observation_spec = OrderedDict()
        for k, v in observation.items():
            observation_spec[k] = v.shape
        return observation_spec

    def action_spec(self):
        """
        Action specification should be implemented in subclasses.

        Action space is represented by a tuple of (low, high), which are two numpy
        vectors that specify the min/max action limits per dimension.
        """
        raise NotImplementedError

    def reset_from_xml_string(self, xml_string):
        """Reloads the environment from an XML description of the environment."""

        # if there is an active viewer window, destroy it
        self.close()

        # load model from xml
        self.mjpy_model = load_model_from_xml(xml_string)

        self.sim = MjSim(self.mjpy_model)
        self.initialize_time(self.control_freq)
        self._setup_sim_renderers()
        self._reset_episode_counters()

        # necessary to refresh MjData
        self.sim.forward()

    def find_contacts(self, geoms_1, geoms_2):
        """
        Finds contact between two geom groups.

        Args:
            geoms_1: a list of geom names (string)
            geoms_2: another list of geom names (string)

        Returns:
            iterator of all contacts between @geoms_1 and @geoms_2
        """
        for contact in self.sim.data.contact[0:self.sim.data.ncon]:
            name1 = self.sim.model.geom_id2name(contact.geom1)
            name2 = self.sim.model.geom_id2name(contact.geom2)
            # a contact matches in either orientation
            if ((name1 in geoms_1 and name2 in geoms_2)
                    or (name1 in geoms_2 and name2 in geoms_1)):
                yield contact

    def _check_contact(self):
        """Returns True if gripper is in contact with an object."""
        return False

    def _check_success(self):
        """
        Returns True if task has been completed.
        """
        return False

    def _destroy_viewer(self):
        """Closes and drops the on-screen viewer, if any."""
        # if there is an active viewer window, destroy it
        if self.viewer is not None:
            self.viewer.close()  # TODO: change this to viewer.finish()?
            self.viewer = None

    def close(self):
        """Do any cleanup necessary here."""
        self._destroy_viewer()

    def get_state(self):
        """Returns the current MjSim state."""
        return self.sim.get_state()

    def set_state(self, qpos, qvel):
        """
        Sets joint positions and velocities (keeping time, act and udd_state) and
        recomputes derived quantities with sim.forward().
        """
        assert qpos.shape == (self.sim.model.nq, ) and qvel.shape == (
            self.sim.model.nv, )
        old_state = self.sim.get_state()
        new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel,
                                         old_state.act, old_state.udd_state)
        self.sim.set_state(new_state)
        self.sim.forward()

    def renderer_on(self):
        """Sets has_renderer; the viewer itself is created on the next reset."""
        self.has_renderer = True

    def renderer_off(self):
        """Clears has_renderer; an existing viewer is not closed here."""
        self.has_renderer = False