Code Example #1
    def render(self, mode='human', close=False):
        if mode == 'human':
            if self.viewer is None:
                self.viewer = MjViewer(self.sim)
            self.viewer.render()
        elif mode == 'rgb_array':
            if self.rgb_viewer is None:
                self.rgb_viewer = MjRenderContextOffscreen(self.sim, 0)
            self.rgb_viewer.render(500, 500)  # window size used for old mujoco-py
            data = self.rgb_viewer.read_pixels(500, 500, depth=False)
            # original image is upside-down, so flip it
            return data[::-1, :, :]
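A minimal, self-contained sketch of the offscreen pattern used above, assuming mujoco-py is installed and an OpenGL context is available; the inline XML is a stand-in model, not taken from any project on this page:

from mujoco_py import MjSim, MjRenderContextOffscreen, load_model_from_xml

MODEL_XML = """
<mujoco>
  <worldbody>
    <light pos="0 0 3"/>
    <geom type="plane" size="1 1 0.1"/>
    <body pos="0 0 0.2">
      <joint type="free"/>
      <geom type="sphere" size="0.1" rgba="1 0 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

sim = MjSim(load_model_from_xml(MODEL_XML))
offscreen = MjRenderContextOffscreen(sim, 0)           # second argument is device_id
offscreen.render(500, 500)                             # (width, height)
frame = offscreen.read_pixels(500, 500, depth=False)
frame = frame[::-1, :, :]                              # read_pixels returns the image upside-down
print(frame.shape)                                     # (500, 500, 3)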
Code Example #2
    def setrendermode(self, rendersetting):
        self.rendermode["rendertype"] = rendersetting["render_flg"]
        if rendersetting["render_flg"]:  # in-screen: ignore other mode parameters
            self.viewer = MjViewer(self.sim)
            return
        else:  # off-screen:
            self.rendermode["W"] = rendersetting["screenwidth"]
            self.rendermode["H"] = rendersetting["screenhight"]
            self.rendermode["pltswitch"] = rendersetting["plt_switch"]
            self.viewer = MjRenderContextOffscreen(self.sim)
            if plt.get_fignums():
                plt.close()
            self.fig, self.ax = plt.subplots()
            return
Code Example #3
    def _reset_internal(self):
        """Resets simulation internal configurations."""

        # create visualization screen or renderer
        if self.has_renderer and self.viewer is None:
            self.viewer = MujocoPyRenderer(self.sim)
            self.viewer.viewer.vopt.geomgroup[0] = (1 if self.render_collision_mesh else 0)
            self.viewer.viewer.vopt.geomgroup[1] = (1 if self.render_visual_mesh else 0)

            # hiding the overlay speeds up rendering significantly
            self.viewer.viewer._hide_overlay = True

            # make sure mujoco-py doesn't block rendering frames
            # (see https://github.com/StanfordVL/robosuite/issues/39)
            self.viewer.viewer._render_every_frame = True

            # Set the camera angle for viewing
            if self.render_camera is not None:
                self.viewer.set_camera(camera_id=self.sim.model.camera_name2id(self.render_camera))

        elif self.has_offscreen_renderer:
            if self.sim._render_context_offscreen is None:
                render_context = MjRenderContextOffscreen(self.sim, device_id=self.render_gpu_device_id)
                self.sim.add_render_context(render_context)
            self.sim._render_context_offscreen.vopt.geomgroup[0] = (1 if self.render_collision_mesh else 0)
            self.sim._render_context_offscreen.vopt.geomgroup[1] = (1 if self.render_visual_mesh else 0)

        # additional housekeeping
        self.sim_state_initial = self.sim.get_state()
        self._get_reference()
        self.cur_time = 0
        self.timestep = 0
        self.done = False
Code Example #4
    def _get_sim(self):
        if self._sim is not None:
            return self._sim

        xml = postprocess_model_xml(self._traj.config_str)
        self._depth_norm = None
        if 'sawyer' in xml:
            from hem.datasets.precompiled_models.sawyer import models
            self._sim = models[0]
            self._depth_norm = 'sawyer'
        elif 'baxter' in xml:
            from hem.datasets.precompiled_models.baxter import models
            self._sim = models[0]
        elif 'panda' in xml:
            from hem.datasets.precompiled_models.panda import models
            self._sim = models[0]
        else:
            model = load_model_from_xml(xml)
            model.vis.quality.offsamples = 8
            # build the sim from the configured model so the offsamples setting takes effect
            sim = MjSim(model)
            render_context = MjRenderContextOffscreen(sim)
            render_context.vopt.geomgroup[0] = 0
            render_context.vopt.geomgroup[1] = 1 
            sim.add_render_context(render_context)
            self._sim = sim

        return self._sim
Code Example #5
File: base.py Project: YaoweiFan/peg-in-hole
    def _reset_internal(self):
        """Resets simulation internal configurations."""
        # instantiate simulation from MJCF model
        self._load_model()
        self.mjpy_model = self.model.get_model(mode="mujoco_py")
        self.sim = MjSim(self.mjpy_model)
        self.initialize_time(self.control_freq)  # set the simulation time step

        # create visualization screen or renderer
        if self.has_renderer and self.viewer is None:
            self.viewer = MujocoPyRenderer(self.sim)
            self.viewer.viewer.vopt.geomgroup[0] = (
                1 if self.render_collision_mesh else 0)
            self.viewer.viewer.vopt.geomgroup[1] = (
                1 if self.render_visual_mesh else 0)

            # hiding the overlay speeds up rendering significantly
            self.viewer.viewer._hide_overlay = True

        elif self.has_offscreen_renderer:
            if self.sim._render_context_offscreen is None:
                render_context = MjRenderContextOffscreen(self.sim)
                self.sim.add_render_context(render_context)
            self.sim._render_context_offscreen.vopt.geomgroup[0] = (
                1 if self.render_collision_mesh else 0)
            self.sim._render_context_offscreen.vopt.geomgroup[1] = (
                1 if self.render_visual_mesh else 0)

        # additional housekeeping
        self.sim_state_initial = self.sim.get_state()
        self._get_reference()
        self.cur_time = 0
        self.timestep = 0
        self.done = False
Code Example #6
File: base.py Project: xuhuahaoren/robosuite
    def reset_from_xml_string(self, xml_string):
        """Reloads the environment from an XML description of the environment."""

        # if there is an active viewer window, destroy it
        self.close()

        # load model from xml
        self.mjpy_model = load_model_from_xml(xml_string)

        self.sim = MjSim(self.mjpy_model)
        self.initialize_time(self.control_freq)
        if self.has_renderer and self.viewer is None:
            self.viewer = MujocoPyRenderer(self.sim)
            self.viewer.viewer.vopt.geomgroup[0] = (
                1 if self.render_collision_mesh else 0
            )
            self.viewer.viewer.vopt.geomgroup[1] = (
                1 if self.render_visual_mesh else 0
            )

            # hiding the overlay speeds up rendering significantly
            self.viewer.viewer._hide_overlay = True

        elif self.has_offscreen_renderer:
            render_context = MjRenderContextOffscreen(self.sim)
            render_context.vopt.geomgroup[0] = 1 if self.render_collision_mesh else 0
            render_context.vopt.geomgroup[1] = 1 if self.render_visual_mesh else 0
            self.sim.add_render_context(render_context)

        self.sim_state_initial = self.sim.get_state()
        self._get_reference()
        self.cur_time = 0
        self.timestep = 0
        self.done = False

        # necessary to refresh MjData
        self.sim.forward()
Code Example #7
def sample_trajectory(env,
                      policy,
                      max_path_length,
                      render=False,
                      render_mode=("rgb_array",)):  # note the trailing comma: a one-element tuple
    # TODO: get this from hw1
    # initialize env for the beginning of a new rollout
    ob = env.reset()  # HINT: should be the output of resetting the env

    global viewer
    if render and viewer is None and hasattr(env, "sim"):
        viewer = MjRenderContextOffscreen(env.sim, 0)

    # init vars
    obs, acs, rewards, next_obs, terminals, image_obs = [], [], [], [], [], []
    steps = 0
    while True:

        # render image of the simulated env
        if render:
            if "rgb_array" in render_mode:
                if hasattr(env, "sim"):
                    try:
                        image_ob = env.sim.render(camera_name="track",
                                                  height=500,
                                                  width=500)
                    except Exception:  # fall back when the env has no "track" camera
                        image_ob = env.sim.render(height=500, width=500)
                    image_obs.append(image_ob[::-1])
                else:
                    image_obs.append(env.render(mode=render_mode))
            if "human" in render_mode:
                env.render(mode=render_mode)
                time.sleep(env.model.opt.timestep)

        # use the most recent ob to decide what to do
        obs.append(ob)
        ac = policy.get_action(ob)  # HINT: query the policy's get_action function
        ac = ac[0]
        acs.append(ac)

        # take that action and record results
        ob, rew, done, _ = env.step(ac)

        # record result of taking that action
        steps += 1
        next_obs.append(ob)
        rewards.append(rew)

        # TODO end the rollout if the rollout ended
        # HINT: rollout can end due to done, or due to max_path_length
        rollout_done = done or steps >= max_path_length  # HINT: this is either 0 or 1
        terminals.append(rollout_done)

        if rollout_done:
            break

    return Path(obs, image_obs, acs, rewards, next_obs, terminals)
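A hypothetical call of the function above, assuming a gym-style env, a policy exposing get_action(), and the course's Path helper; note that the module-level `viewer` must exist before the first rendered rollout:

viewer = None
path = sample_trajectory(env, policy, max_path_length=200,
                         render=True, render_mode=("rgb_array",))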
Code Example #8
def create_model(xml):
    model = load_model_from_xml(postprocess_model_xml(xml))
    model.vis.quality.offsamples = 8

    sim = MjSim(model)
    render_context = MjRenderContextOffscreen(sim)
    render_context.vopt.geomgroup[0] = 0
    render_context.vopt.geomgroup[1] = 1
    sim.add_render_context(render_context)
    return sim
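Hypothetical usage of create_model above; `xml_string` stands in for an MJCF document. Setting geomgroup[0] = 0 hides collision geoms (group 0) and geomgroup[1] = 1 keeps visual geoms (group 1), so rendered frames show only the visual meshes:

sim = create_model(xml_string)
ctx = sim._render_context_offscreen  # the context registered by add_render_context
ctx.render(256, 256)
img = ctx.read_pixels(256, 256, depth=False)[::-1]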
Code Example #9
File: render_env.py Project: geyang/env-wrappers
    def _get_viewer(self, mode):
        self.viewer = self._viewers.get(mode)
        if self.viewer is None:
            if mode == 'human':
                self.viewer = MjViewer(self.sim)
            else:
                self.viewer = MjRenderContextOffscreen(self.sim, -1)
            self._viewers[mode] = self.viewer
        self.viewer_setup()  # runs on every lookup, so no separate call is needed at creation
        return self.viewer
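Hedged usage of the per-mode viewer cache: one environment can serve both an interactive window and offscreen frames without rebuilding contexts (`env` is an assumed instance of the class above):

env._get_viewer('human').render()       # opens or reuses the MjViewer window
off = env._get_viewer('rgb_array')      # reuses the MjRenderContextOffscreen
off.render(320, 240)                    # (width, height)
frame = off.read_pixels(320, 240, depth=False)[::-1]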
Code Example #10
    def __init__(self, model_path, frame_skip):
        if model_path.startswith("/"):
            fullpath = model_path
        else:
            fullpath = os.path.join(os.path.dirname(__file__), "assets",
                                    model_path)
        if not path.exists(fullpath):
            raise IOError("File %s does not exist" % fullpath)
        self.frame_skip = frame_skip
        self.model = mujoco_py.load_model_from_path(fullpath)
        self.sim = mujoco_py.MjSim(self.model)
        self.renderer = MjRenderContextOffscreen(self.sim,
                                                 device_id=get_gpu_id())
        self.x_dim = 84
        self.y_dim = 84
        self.data = self.sim.data
        self.viewer = None

        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt))
        }

        self.init_qpos = self.sim.data.qpos.ravel().copy()
        self.init_qvel = self.sim.data.qvel.ravel().copy()
        observation, _reward, done, _info = self.step(np.zeros(self.model.nu))
        assert not done

        bounds = self.model.actuator_ctrlrange.copy()
        low = bounds[:, 0]
        high = bounds[:, 1]
        self.action_space = spaces.Box(low=low, high=high, dtype=np.float32)

        if isinstance(observation, dict):
            self.observation_space = {}
            for key, value in observation.items():
                self.observation_space[key] = spaces.Box(-np.inf,
                                                         np.inf,
                                                         shape=value.shape,
                                                         dtype=value.dtype)
        else:
            obs_dim = observation.size
            high = np.inf * np.ones(obs_dim)
            low = -high
            self.observation_space = spaces.Box(low, high, dtype=np.float32)

        self.seed()
Code Example #11
    def _viewer_setup(self, rendersetting):
        self.rendermode["rendertype"] = rendersetting["render_flg"]

        if rendersetting["render_flg"]:  # in-screen: ignore other mode parameters
            self.viewer = MjViewer(self.sim)
        else:  # off-screen:
            self.rendermode["W"] = rendersetting["screenwidth"]
            self.rendermode["H"] = rendersetting["screenhight"]
            self.rendermode["pltswitch"] = rendersetting["plt_switch"]
            self.viewer = MjRenderContextOffscreen(self.sim)
            if plt.get_fignums():
                plt.close()
            self.fig, self.ax = plt.subplots()

        # body_id = self.sim.model.body_name2id('robot0:gripper_link')
        # lookat = self.sim.data.body_xpos[body_id]
        # for idx, value in enumerate(lookat):
        #     self.viewer.cam.lookat[idx] = value
        self.viewer.cam.lookat[:] = rendersetting["lookat"]
        self.viewer.cam.distance = rendersetting["distance"]
        self.viewer.cam.azimuth = rendersetting["azimuth"]
        self.viewer.cam.elevation = rendersetting["elevation"]
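A hypothetical rendersetting dict covering every key _viewer_setup reads (values are illustrative; note the source spells the key "screenhight"):

rendersetting = {
    "render_flg": False,          # False selects the off-screen branch
    "screenwidth": 500,
    "screenhight": 500,
    "plt_switch": True,           # display frames through matplotlib
    "lookat": [1.3, 0.75, 0.4],
    "distance": 1.2,
    "azimuth": 180,
    "elevation": -25,
}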
Code Example #12
def _eval_agent(a_n, env, args, image_based=True, cuda=False):

        # load model
        env = gym.make('FetchReach-v1')
        sim = env.sim
        viewer = MjRenderContextOffscreen(sim)
        # self.viewer.cam.fixedcamid = 3
        # self.viewer.cam.type = const.CAMERA_FIXED
        viewer.cam.distance = 1.2
        viewer.cam.azimuth = 180
        viewer.cam.elevation = -25
        env.env._viewers['rgb_array'] = viewer

        model_path = './test.pt'
        loaded_model = new_actor(get_env_params(env))
        loaded_model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))

        if cuda:
            loaded_model.cuda()

        total_success_rate = []
        for _ in range(args.n_test_rollouts):
            per_success_rate = []
            observation = env.reset()
            obs = observation['observation']
            g = observation['desired_goal']
            obs_img = env.render(mode="rgb_array", height=100, width=100)
            for _ in range(env._max_episode_steps):
                with torch.no_grad():
                    if image_based:
                        o_tensor, g_tensor = _preproc_inputs_image(obs_img.copy()[np.newaxis, :], g[np.newaxis, :], cuda)
                        pi = loaded_model(o_tensor, g_tensor)
                    else:
                        input_tensor = _preproc_inputs(obs, g)  # module-level helper, as with _preproc_inputs_image
                        pi = loaded_model(input_tensor)
                    # convert the actions
                    actions = pi.detach().cpu().numpy().squeeze()
                observation_new, _, _, info = env.step(actions)
                obs = observation_new['observation']
                obs_img = env.render(mode="rgb_array", height=100, width=100)
                g = observation_new['desired_goal']
                per_success_rate.append(info['is_success'])
            total_success_rate.append(per_success_rate)
        total_success_rate = np.array(total_success_rate)
        local_success_rate = np.mean(total_success_rate[:, -1])
        print(local_success_rate)
Code Example #13
    def __init__(self, model_path, initial_qpos, n_actions, n_substeps):
        if model_path.startswith('/'):
            fullpath = model_path
        else:
            fullpath = os.path.join(os.path.dirname(__file__), 'assets',
                                    model_path)
        if not os.path.exists(fullpath):
            raise IOError('File {} does not exist'.format(fullpath))

        model = mujoco_py.load_model_from_path(fullpath)
        self.sim = mujoco_py.MjSim(model, nsubsteps=n_substeps)
        self.viewer = None

        self.renderer = MjRenderContextOffscreen(self.sim, device_id=1)

        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt))
        }

        self.seed()
        self._env_setup(initial_qpos=initial_qpos)
        self.initial_state = copy.deepcopy(self.sim.get_state())

        self.goal = self._sample_goal()
        obs = self._get_obs()
        self.action_space = spaces.Box(-1.,
                                       1.,
                                       shape=(n_actions, ),
                                       dtype='float32')
        self.observation_space = spaces.Dict(
            dict(
                desired_goal=spaces.Box(-np.inf,
                                        np.inf,
                                        shape=obs['achieved_goal'].shape,
                                        dtype='float32'),
                achieved_goal=spaces.Box(-np.inf,
                                         np.inf,
                                         shape=obs['achieved_goal'].shape,
                                         dtype='float32'),
                observation=spaces.Box(-np.inf,
                                       np.inf,
                                       shape=obs['observation'].shape,
                                       dtype='float32'),
            ))
Code Example #14
        lightid = model.light_name2id(name)
        value = model.light_ambient[lightid]
        print(value)
        model.light_ambient[lightid] = value + 1


# model = load_model_from_path("/Users/karanchahal/projects/mujoco-py/xmls/fetch/main.xml")
# model = load_model_from_path("/Users/karanchahal/miniconda3/envs/rlkit/lib/python3.6/site-packages/gym/envs/robotics/assets/fetch/pick_and_place.xml")
# sim = MjSim(model)

env = gym.make('FetchPickAndPlace-v1')
# exit()
# env.sim = sim

sim = env.sim
viewer = MjRenderContextOffscreen(sim)
viewer.cam.fixedcamid = 3
viewer.cam.type = const.CAMERA_FIXED
env.env._viewers['rgb_array'] = viewer
im = env.render(mode="rgb_array")
plt.imshow(im)
plt.show()

modder = TextureModder(sim)
# modder = CameraModder(sim)
modder.whiten_materials()
# modder = MaterialModder(sim)

t = 1

# viewer.cam.fixedcamid = 3
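A hedged continuation sketch: the example sets up a TextureModder but is cut off before using it. In mujoco-py, TextureModder.rand_all(name) re-randomizes one named geom's texture, so a typical domain-randomization step might look like:

for name in sim.model.geom_names:
    modder.rand_all(name)   # assumes each geom has a material/texture assigned
im = env.render(mode="rgb_array")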
Code Example #15
    # camera #2
    vieweroff.render(width, height, camera_id=1)  # if camera_id=None, camera is not rendered
    rgb = vieweroff.read_pixels(width, height)[0]
    bgr = np.flipud(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    cv2.imshow('teste2', bgr)

    cv2.waitKey(1)


model = mujoco_py.load_model_from_path("./assets/my-model-rrr.xml")

''' ATTENTION: if you choose to use Mujoco's default viewer, you can't see the rendering of the cameras!'''

sim = MjSim(model)
vieweroff = MjRenderContextOffscreen(sim,0)

# controller and simulation params
t = 0
qpos_ref = np.array([-2, -1, 2])
qvel_ref = np.array([0, 0, 0])
kp = 1000
kv = 500

sim.model.opt.gravity[:] = np.array([0, 0, 0])  # just to make simulation easier :)

width, height = 800, 480

t_ini = time.time()

try:
Code Example #16
class ImiRob:
    """Imitation robot that replays actions extracted from other robots.

    This simulates a quadruped robot with 4 legs and 2 hinge joints per leg.
    The simulation is based on MuJoCo. All robots are derived from the OpenAI Gym ant.
    """

    def __init__(self,rendersetting,frame_skip=4,objnum=4,showarm=False):
        if showarm:
            xmlsource = '/home/cml/CML/MjPushData/xmls/show_env.xml'
        else:
            xmlsource = '/home/cml/CML/MjPushData/xmls/lab_env.xml'

        self.model =  load_model_from_path(xmlsource)
        self.sim = MjSim(self.model)
        self.frame_skip = frame_skip
        self.rendermode = {}
        self.init_state = self.sim.get_state()

        self.viewer = None
        self.setrendermode(rendersetting)
        self.stopflag = False
        self.objnum = objnum
        self.showarm = showarm

   

        self.cubes = [TableObj(["free_x_" + str(i + 1), "free_y_" + str(i + 1)], self.model)
                      for i in range(self.objnum)]
        self.move_dict = {
            0: [0.3, 0],
            1: [-0.3, 0],
            2: [0, 0.3],
            3: [0, -0.3],
        }

    def init_cubes(self,mjqpos):
        for i in range(self.objnum):
            pos = [np.random.uniform(low=-1.0, high=1.0),np.random.uniform(low=-1.0, high=1.0)]
            # pos =[i*0.3,i*0.3]
            mjqpos = self.cubes[i].set_pos(pos,mjqpos)
        
        return mjqpos
    
    def mesh_init_cubes(self,mjqpos,gridsize=0.5):
        def mesh_random(pt_num,swing=0.5,seglow=-1.0,seghigh=1.0):
            res = []
            last_pt = seglow - swing
            for remain_pt_num in reversed(range(pt_num)):
                segrange = [last_pt+swing,seghigh - remain_pt_num*swing]
                last_pt = np.random.uniform(low=segrange[0],high=segrange[1])
                res.append(last_pt)
            return np.random.permutation(res)
        
        self.x_coord = mesh_random(self.objnum,swing=gridsize)
        self.y_coord = mesh_random(self.objnum,swing=gridsize)
        
        for i in range(self.objnum):
            pos = [self.x_coord[i],self.y_coord[i]]
            mjqpos = self.cubes[i].set_pos(pos,mjqpos)
        
        return mjqpos
    
    def check_contacts(self):
        def checkdistance(poses,scaling=0.6,eps=0.1):
            poses = sorted(poses)
            dis = np.array([np.abs(poses[i]-poses[i+1]) for i in range(len(poses)-1)])
            return np.any(dis < eps)

        cube_pos_x = [self.cubes[i].pos[0] for i in range(self.objnum)]
        cube_pos_y = [self.cubes[i].pos[1] for i in range(self.objnum)]
        return checkdistance(cube_pos_x) or checkdistance(cube_pos_y)

    def setrendermode(self, rendersetting):
        self.rendermode["rendertype"] = rendersetting["render_flg"]
        if rendersetting["render_flg"]:  # in-screen: ignore other mode parameters
            self.viewer = MjViewer(self.sim)
            return
        else:  # off-screen:
            self.rendermode["W"] = rendersetting["screenwidth"]
            self.rendermode["H"] = rendersetting["screenhight"]
            self.rendermode["pltswitch"] = rendersetting["plt_switch"]
            self.viewer = MjRenderContextOffscreen(self.sim)
            if plt.get_fignums():
                plt.close()
            self.fig, self.ax = plt.subplots()
            return

    def onscreenshow(self):
        self.viewer.render()

    def offscreenshow(self):
        # mujoco-py's render() and read_pixels() take (width, height)
        self.viewer.render(self.rendermode["W"], self.rendermode["H"])
        im_src = self.viewer.read_pixels(self.rendermode["W"], self.rendermode["H"], depth=False)
        im_src = im_src[::-1]  # rendered image is upside-down, so flip it vertically

        if self.rendermode["pltswitch"]:
            self.ax.cla()
            self.ax.imshow(im_src)
            plt.pause(1e-10)

        return im_src, not plt.get_fignums()

    def set_state(self, qpos, qvel):
        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
        old_state = self.sim.get_state()
        new_state = MjSimState(old_state.time, qpos, qvel,
                                         old_state.act, old_state.udd_state)
        self.sim.set_state(new_state)
        self.sim.forward()

    def reset(self):
        self.sim.set_state(self.init_state)
        self.stopflag = False

        self.qposrecrd = []

        self.init_qpos = self.sim.data.qpos.ravel().copy()
        self.init_qvel = self.sim.data.qvel.ravel().copy()
        qpos = self.init_qpos.copy()
        qvel = self.init_qvel.copy()

        if self.showarm:
            qpos = np.array([0.2, -90 / 180 * math.pi, 70 / 180 * math.pi,
                             0, 0, 0,            # robotic arm
                             0, 0, 0, 0,
                             0, 0, 0, 0,         # two fingers
                             -0.62, 0.32,
                             -0.72, 0.38,
                             -0.835, 0.425,
                             -0.935, 0.46])      # 4 cubes

        qpos = self.mesh_init_cubes(qpos)
        self.set_state(qpos, qvel)

    def setcam(self, distance=3, azimuth=0, elevation=-90, lookat=[-0.2, 0, 0],
               hideoverlay=True, trackbodyid=-1):
        self.viewer.cam.trackbodyid = trackbodyid
        self.viewer.cam.distance = distance
        self.viewer.cam.azimuth = azimuth
        self.viewer.cam.elevation = elevation
        self.viewer.cam.lookat[:] = lookat
        self.viewer._hide_overlay = hideoverlay

    def moverobot(self):
        self.init_qpos = self.sim.data.qpos.ravel().copy()
        self.init_qvel = self.sim.data.qvel.ravel().copy()
        qpos = self.init_qpos.copy()
        qvel = self.init_qvel.copy()

        qpos = self.cubes[0].move_cube([0.2, 0], qpos)
        self.set_state(qpos, qvel)

    def random_move_cube(self):
        move_fashion, move_cube = np.random.randint(4,size=2)
        self.init_qpos = self.sim.data.qpos.ravel().copy()
        self.init_qvel = self.sim.data.qvel.ravel().copy()
        qpos = self.init_qpos.copy()
        qvel = self.init_qvel.copy()
        qpos = self.cubes[move_cube].move_cube(self.move_dict[move_fashion], qpos)
        self.set_state(qpos, qvel)

        return move_cube,move_fashion

    def save_video(self, savepath):
        video_data = np.zeros((2, self.rendermode["H"], self.rendermode["W"], 3)).astype(np.uint8)
        im1, _ = self.offscreenshow()
        movecube, movefashion = self.random_move_cube()
        im2, _ = self.offscreenshow()
        video_data[0, ...] = im1
        video_data[1, ...] = im2

        np.savez_compressed(savepath, video=video_data, cubeindx=movecube, actionindx=movefashion)

    def __del__(self):
        del self.viewer
        del self.model
        del self.sim
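A hypothetical driver for ImiRob, assuming the hard-coded XML paths in __init__ exist on disk and using the off-screen branch of setrendermode:

rendersetting = {"render_flg": False, "screenwidth": 500,
                 "screenhight": 500, "plt_switch": False}
robot = ImiRob(rendersetting)
robot.setcam(distance=3, azimuth=0, elevation=-90)
robot.reset()
robot.save_video("push_pair.npz")  # stores two frames plus cube/action indices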
Code Example #17
class PendulumWithGoals(gym.Env):
    metadata = {
        'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 30
    }

    def __init__(self, goal_reaching_thresholds=np.array([0.075, 0.075, 0.75]),
                 goal_not_reached_penalty=-1, goal_reached_reward=0, terminate_on_goal_reaching=True,
                 time_limit=1000, frameskip=1, random_goals_instead_of_standing_goal=False,
                 polar_coordinates: bool=False):
        super().__init__()
        base_dir = os.path.dirname(__file__)
        model = load_model_from_path(base_dir + "/pendulum_with_goals.xml")

        self.sim = MjSim(model)
        self.viewer = None
        self.rgb_viewer = None

        self.frameskip = frameskip
        self.goal = None
        self.goal_reaching_thresholds = goal_reaching_thresholds
        self.goal_not_reached_penalty = goal_not_reached_penalty
        self.goal_reached_reward = goal_reached_reward
        self.terminate_on_goal_reaching = terminate_on_goal_reaching
        self.time_limit = time_limit
        self.current_episode_steps_counter = 0
        self.random_goals_instead_of_standing_goal = random_goals_instead_of_standing_goal
        self.polar_coordinates = polar_coordinates

        # spaces definition
        self.action_space = spaces.Box(low=-self.sim.model.actuator_ctrlrange[:, 1],
                                       high=self.sim.model.actuator_ctrlrange[:, 1],
                                       dtype=np.float32)
        if self.polar_coordinates:
            self.observation_space = spaces.Dict({
                "observation": spaces.Box(low=np.array([-np.pi, -15]),
                                          high=np.array([np.pi, 15]),
                                          dtype=np.float32),
                "desired_goal": spaces.Box(low=np.array([-np.pi, -15]),
                                           high=np.array([np.pi, 15]),
                                           dtype=np.float32),
                "achieved_goal": spaces.Box(low=np.array([-np.pi, -15]),
                                            high=np.array([np.pi, 15]),
                                            dtype=np.float32)
            })
        else:
            self.observation_space = spaces.Dict({
                "observation": spaces.Box(low=np.array([-1, -1, -15]),
                                          high=np.array([1, 1, 15]),
                                          dtype=np.float32),
                "desired_goal": spaces.Box(low=np.array([-1, -1, -15]),
                                           high=np.array([1, 1, 15]),
                                           dtype=np.float32),
                "achieved_goal": spaces.Box(low=np.array([-1, -1, -15]),
                                            high=np.array([1, 1, 15]),
                                            dtype=np.float32)
            })

        self.spec = EnvSpec('PendulumWithGoals-v0')
        self.spec.reward_threshold = self.goal_not_reached_penalty * self.time_limit

        self.reset()

    def _goal_reached(self):
        observation = self._get_obs()
        if np.any(np.abs(observation['achieved_goal'] - observation['desired_goal']) > self.goal_reaching_thresholds):
            return False
        else:
            return True

    def _terminate(self):
        if (self._goal_reached() and self.terminate_on_goal_reaching) or \
                        self.current_episode_steps_counter >= self.time_limit:
            return True
        else:
            return False

    def _reward(self):
        if self._goal_reached():
            return self.goal_reached_reward
        else:
            return self.goal_not_reached_penalty

    def step(self, action):
        self.sim.data.ctrl[:] = action
        for _ in range(self.frameskip):
            self.sim.step()

        self.current_episode_steps_counter += 1

        state = self._get_obs()

        # visualize the angular velocities
        state_velocity = np.copy(state['observation'][-1] / 20)
        goal_velocity = self.goal[-1] / 20
        self.sim.model.site_size[2] = np.array([0.01, 0.01, state_velocity])
        self.sim.data.mocap_pos[2] = np.array([0.85, 0, 0.75 + state_velocity])
        self.sim.model.site_size[3] = np.array([0.01, 0.01, goal_velocity])
        self.sim.data.mocap_pos[3] = np.array([1.15, 0, 0.75 + goal_velocity])

        return state, self._reward(), self._terminate(), {}

    def _get_obs(self):

        """
        y

        ^
        |____
        |   /
        |  /
        |~/
        |/
        --------> x

        """

        # observation
        angle = self.sim.data.qpos
        angular_velocity = self.sim.data.qvel
        if self.polar_coordinates:
            observation = np.concatenate([angle - np.pi, angular_velocity])
        else:
            x = np.sin(angle)
            y = np.cos(angle)  # qpos is the angle relative to a standing pole
            observation = np.concatenate([x, y, angular_velocity])

        return {
            "observation": observation,
            "desired_goal": self.goal,
            "achieved_goal": observation
        }

    def reset(self):
        self.current_episode_steps_counter = 0

        # set initial state
        angle = np.random.uniform(np.pi / 4, 7 * np.pi / 4)
        angular_velocity = np.random.uniform(-0.05, 0.05)
        self.sim.data.qpos[0] = angle
        self.sim.data.qvel[0] = angular_velocity
        self.sim.step()

        # goal
        if self.random_goals_instead_of_standing_goal:
            angle_target = np.random.uniform(-np.pi / 8, np.pi / 8)
            angular_velocity_target = np.random.uniform(-0.2, 0.2)
        else:
            angle_target = 0
            angular_velocity_target = 0

        # convert target values to goal
        x_target = np.sin(angle_target)
        y_target = np.cos(angle_target)
        if self.polar_coordinates:
            self.goal = np.array([angle_target - np.pi, angular_velocity_target])
        else:
            self.goal = np.array([x_target, y_target, angular_velocity_target])

        # visualize the goal
        self.sim.data.mocap_pos[0] = [x_target, 0, y_target]

        return self._get_obs()

    def render(self, mode='human', close=False):
        if mode == 'human':
            if self.viewer is None:
                self.viewer = MjViewer(self.sim)
            self.viewer.render()
        elif mode == 'rgb_array':
            if self.rgb_viewer is None:
                self.rgb_viewer = MjRenderContextOffscreen(self.sim, 0)
            self.rgb_viewer.render(500, 500)
            # window size used for old mujoco-py:
            data = self.rgb_viewer.read_pixels(500, 500, depth=False)
            # original image is upside-down, so flip it
            return data[::-1, :, :]
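A hypothetical rollout of PendulumWithGoals, assuming pendulum_with_goals.xml sits next to the module:

env = PendulumWithGoals(random_goals_instead_of_standing_goal=True)
obs = env.reset()
for _ in range(100):
    obs, reward, done, info = env.step(env.action_space.sample())
    frame = env.render(mode='rgb_array')  # 500x500 RGB array
    if done:
        obs = env.reset()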
Code Example #18
        has_offscreen_renderer=False,  # not needed since not using pixel obs
        has_renderer=False,  # make sure we can render to the screen
        reward_shaping=True,  # use dense rewards
        control_freq=30,  # control should happen fast enough so that simulation looks smooth
    )
    world.mode = 'human'
    world.reset()
    print(dir(world.mujoco_arena))
    print(world.mujoco_arena.table_full_size)
    print(world.mujoco_arena.bin_abs)

    print(type(world.sim))
    sim = world.sim
    print(dir(sim))
    viewer = MjRenderContextOffscreen(sim, 0)
    print(dir(viewer.scn))
    set_camera_birdview(viewer)
    width, height = 516, 386
    viewer.render(width, height)
    image = np.asarray(viewer.read_pixels(width, height, depth=False),
                       dtype=np.uint8)
    depth = np.asarray(viewer.read_pixels(width, height, depth=True)[1])
    # find center depth, this is a hack
    cdepth = depth[height // 2, width // 2]
    print(cdepth)
    depth[depth > cdepth] = cdepth
    seg = depth != cdepth
    seg[:, :padding], seg[:, -padding:] = False, False
    seg[:t_padding, :], seg[-t_padding:, :] = False, False
    cv2.imwrite('test_dataset/seg.png', seg * 255)
Code Example #19
    model_path = './test.pt'
    model = torch.load(model_path, map_location=lambda storage, loc: storage)
    # create the environment
    env = gym.make(args.env_name)
    # get the env param
    observation = env.reset()
    # get the environment params
    env_params = {'obs': observation['observation'].shape[0], 
                  'goal': observation['desired_goal'].shape[0], 
                  'action': env.action_space.shape[0], 
                  'action_max': env.action_space.high[0],
                  }

    if image_based:
        sim = env.sim
        viewer = MjRenderContextOffscreen(sim)
        viewer.cam.fixedcamid = 3
        viewer.cam.type = const.CAMERA_FIXED
        env.env._viewers['rgb_array'] = viewer

    # create the actor network
    actor_network = new_actor(env_params)
    actor_network.load_state_dict(model)
    # actor_network.eval()
    _eval_agent(actor_network, env, args)
    exit()
    for i in range(args.demo_length):
        observation = env.reset()
        # start to do the demo
        obs = observation['observation']
        g = observation['desired_goal']
Code Example #20
class FetchPushEnv(FetchEnv):
    def __init__(self, rendersetting, objnum=4, reward_type='sparse'):
        initial_qpos = {
            'robot0:slide0': 0.405,
            'robot0:slide1': 0.48,
            'robot0:slide2': 0.0,
        }

        FetchEnv.__init__(self,
                          MODEL_XML_PATH,
                          has_object=True,
                          block_gripper=True,
                          n_substeps=20,
                          gripper_extra_height=0.0,
                          target_in_the_air=False,
                          target_offset=0.0,
                          obj_range=0.15,
                          target_range=0.15,
                          distance_threshold=0.05,
                          initial_qpos=initial_qpos,
                          reward_type=reward_type)
        utils.EzPickle.__init__(self)

        self.rendermode = {}
        self._viewer_setup(rendersetting)
        self.objnum = objnum

        self.objs = [
            TableObj("object{}:joint".format(i), self.model)
            for i in range(self.objnum)
        ]

        self.movedict = {
            "dimidx": [
                np.array([1.0, 0.0]),   # x - pos
                np.array([-1.0, 0.0]),  # x - neg
                np.array([0.0, 1.0]),   # y - pos
                np.array([0.0, -1.0]),  # y - neg
            ],
        }

    def _viewer_setup(self, rendersetting):
        self.rendermode["rendertype"] = rendersetting["render_flg"]

        if rendersetting["render_flg"]:  # in-screen: ignore other mode parameters
            self.viewer = MjViewer(self.sim)
        else:  # off-screen:
            self.rendermode["W"] = rendersetting["screenwidth"]
            self.rendermode["H"] = rendersetting["screenhight"]
            self.rendermode["pltswitch"] = rendersetting["plt_switch"]
            self.viewer = MjRenderContextOffscreen(self.sim)
            if plt.get_fignums():
                plt.close()
            self.fig, self.ax = plt.subplots()

        # body_id = self.sim.model.body_name2id('robot0:gripper_link')
        # lookat = self.sim.data.body_xpos[body_id]
        # for idx, value in enumerate(lookat):
        #     self.viewer.cam.lookat[idx] = value
        self.viewer.cam.lookat[:] = rendersetting["lookat"]
        self.viewer.cam.distance = rendersetting["distance"]
        self.viewer.cam.azimuth = rendersetting["azimuth"]
        self.viewer.cam.elevation = rendersetting["elevation"]

    def offscreenrender(self):
        # mujoco-py's render() and read_pixels() take (width, height)
        self.viewer.render(self.rendermode["W"], self.rendermode["H"])
        im_src = self.viewer.read_pixels(self.rendermode["W"],
                                         self.rendermode["H"],
                                         depth=False)
        im_src = im_src[::-1]  # rendered image is upside-down, so flip it vertically

        if self.rendermode["pltswitch"]:
            self.ax.cla()
            self.ax.imshow(im_src)
            plt.pause(1e-10)

        return im_src, not plt.get_fignums()

    def move_gripper(self, newpos):
        # Move end effector into position.
        gripper_target = np.array(newpos)
        gripper_rotation = np.array([1., 0., 1., 0.])
        self.sim.data.set_mocap_pos('robot0:mocap', gripper_target)
        self.sim.data.set_mocap_quat('robot0:mocap', gripper_rotation)
        for _ in range(10):
            self.sim.step()
            self.viewer.render()
            # self.offscreenrender()

    def mesh_init_cubes(self, gridsize=0.4):
        def mesh_random(pt_num, swing=0.375, seglow=-0.7, seghigh=0.8):
            res = []
            last_pt = seglow - swing
            for remain_pt_num in reversed(range(pt_num)):
                segrange = [last_pt + swing, seghigh - remain_pt_num * swing]
                last_pt = np.random.uniform(low=segrange[0], high=segrange[1])
                res.append(last_pt)
            return np.random.permutation(res)

        self.x_coord = mesh_random(self.objnum, swing=gridsize)
        self.y_coord = mesh_random(self.objnum, swing=gridsize)

        for i in range(self.objnum):
            pos = [self.x_coord[i], self.y_coord[i]]
            self.objs[i].set_pos(pos, self.sim)
        self.sim.forward()

    def random_push(self, savepath):
        video_data = np.zeros((2, self.rendermode["H"], self.rendermode["W"],
                               3)).astype(np.uint8)
        move_fashion, move_cube = np.random.randint(4, size=2)

        im1, _ = self.offscreenrender()
        self.push_obj(move_fashion, move_cube)
        im2, _ = self.offscreenrender()

        video_data[0, ...] = im1
        video_data[1, ...] = im2
        np.savez_compressed(savepath,
                            video=video_data,
                            cubeindx=move_cube,
                            actionindx=move_fashion)

    def push_obj(self, move_fashion, move_cube):

        objpos = np.array(self.objs[move_cube].get_pos())
        dimidx = self.movedict["dimidx"][move_fashion]

        objpos += np.append(dimidx * 0.06, [0.4])
        # objpos[1]+= 0.07
        # objpos[2]+= 0.4
        self.move_gripper(objpos)

        objpos[2] -= 0.38
        self.move_gripper(objpos)

        for _ in range(10):
            self._set_action(np.append(dimidx * -0.4, [0.0, 0.0]))
            # self._set_action(np.array([0.0,-0.4,0,0]))
            self.sim.step()
            # self.viewer.render()
            # self.offscreenrender()

        for _ in range(10):
            self._set_action(np.append(dimidx * 0.1, [0.0, 0.0]))
            # self._set_action(np.array([0.0,0.1,0,0]))
            self.sim.step()
            self.viewer.render()
            # self.offscreenrender()

        objpos[2] += 0.7
        self.move_gripper(objpos)

    def reset(self):
        self._reset_sim()
        self.mesh_init_cubes()
        self.move_gripper([1.2, 0.75, 0.8])
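A hypothetical data-collection loop for FetchPushEnv, assuming gym's Fetch assets and an off-screen rendersetting dict like the one sketched for _viewer_setup above:

env = FetchPushEnv(rendersetting, objnum=4)
for episode in range(10):
    env.reset()
    env.random_push("push_{:03d}.npz".format(episode))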
Code Example #21
    def build(self):
        ''' Build a world, including generating XML and moving objects '''
        # Read in the base XML (contains robot, camera, floor, etc)
        self.robot_base_path = os.path.join(BASE_DIR, self.robot_base)
        with open(self.robot_base_path) as f:
            self.robot_base_xml = f.read()
        self.xml = xmltodict.parse(self.robot_base_xml)  # Nested OrderedDict objects

        # Convenience accessor for xml dictionary
        worldbody = self.xml['mujoco']['worldbody']

        # Move robot position to starting position
        worldbody['body']['@pos'] = convert(np.r_[self.robot_xy,
                                                  self.robot.z_height])
        worldbody['body']['@quat'] = convert(rot2quat(self.robot_rot))

        # We need this because xmltodict skips over single-item lists in the tree
        worldbody['body'] = [worldbody['body']]
        if 'geom' in worldbody:
            worldbody['geom'] = [worldbody['geom']]
        else:
            worldbody['geom'] = []

        # Add equality section if missing
        if 'equality' not in self.xml['mujoco']:
            self.xml['mujoco']['equality'] = OrderedDict()
        equality = self.xml['mujoco']['equality']
        if 'weld' not in equality:
            equality['weld'] = []

        # Add asset section if missing
        if 'asset' not in self.xml['mujoco']:
            # old default rgb1: ".4 .5 .6"
            # old default rgb2: "0 0 0"
            # light pink: "1 0.44 .81"
            # light blue: "0.004 0.804 .996"
            # light purple: ".676 .547 .996"
            # med blue: "0.527 0.582 0.906"
            # indigo: "0.293 0 0.508"
            asset = xmltodict.parse('''
                <asset>
                    <texture type="skybox" builtin="gradient" rgb1="0.527 0.582 0.906" rgb2="0.1 0.1 0.35"
                        width="800" height="800" markrgb="1 1 1" mark="random" random="0.001"/>
                    <texture name="texplane" builtin="checker" height="100" width="100"
                        rgb1="0.7 0.7 0.7" rgb2="0.8 0.8 0.8" type="2d"/>
                    <material name="MatPlane" reflectance="0.1" shininess="0.1" specular="0.1"
                        texrepeat="10 10" texture="texplane"/>
                </asset>
                ''')
            self.xml['mujoco']['asset'] = asset['asset']

        # Add light to the XML dictionary
        light = xmltodict.parse('''<b>
            <light cutoff="100" diffuse="1 1 1" dir="0 0 -1" directional="true"
                exponent="1" pos="0 0 0.5" specular="0 0 0" castshadow="false"/>
            </b>''')
        worldbody['light'] = light['b']['light']

        # Add floor to the XML dictionary if missing
        if not any(g.get('@name') == 'floor' for g in worldbody['geom']):
            floor = xmltodict.parse('''
                <geom name="floor" type="plane" condim="6"/>
                ''')
            worldbody['geom'].append(floor['geom'])

        # Make sure floor renders the same for every world
        for g in worldbody['geom']:
            if g['@name'] == 'floor':
                g.update({
                    '@size': convert(self.floor_size),
                    '@rgba': '1 1 1 1',
                    '@material': 'MatPlane'
                })

        # Add cameras to the XML dictionary
        cameras = xmltodict.parse('''<b>
            <camera name="fixednear" pos="0 -2 2" zaxis="0 -1 1"/>
            <camera name="fixedfar" pos="0 -5 5" zaxis="0 -1 1"/>
            <camera name="fixedtop" pos="0 0 5" zaxis="0 0 1"/>
            </b>''')
        worldbody['camera'] = cameras['b']['camera']

        # Build and add a tracking camera (logic needed to ensure orientation correct)
        theta = self.robot_rot
        xyaxes = dict(x1=np.cos(theta),
                      x2=-np.sin(theta),
                      x3=0,
                      y1=np.sin(theta),
                      y2=np.cos(theta),
                      y3=1)
        pos = dict(xp=0 * np.cos(theta) + (-2) * np.sin(theta),
                   yp=0 * (-np.sin(theta)) + (-2) * np.cos(theta),
                   zp=2)
        track_camera = xmltodict.parse('''<b>
            <camera name="track" mode="track" pos="{xp} {yp} {zp}" xyaxes="{x1} {x2} {x3} {y1} {y2} {y3}"/>
            </b>'''.format(**pos, **xyaxes))

        # support multiple cameras on 'body'
        if not isinstance(worldbody['body'][0]['camera'], list):
            worldbody['body'][0]['camera'] = [worldbody['body'][0]['camera']]
        worldbody['body'][0]['camera'].append(track_camera['b']['camera'])

        # Add objects to the XML dictionary
        for name, object in self.objects.items():
            assert object['name'] == name, f'Inconsistent {name} {object}'
            object = object.copy()  # don't modify original object
            object['quat'] = rot2quat(object['rot'])
            if name == 'box':
                dim = object['size'][0]
                object['height'] = object['size'][-1]
                object['width'] = dim / 2
                object['x'] = dim
                object['y'] = dim
                body = xmltodict.parse(
                    '''
                    <body name="{name}" pos="{pos}" quat="{quat}">
                        <freejoint name="{name}"/>
                        <geom name="{name}" type="{type}" size="{size}" density="{density}"
                            rgba="{rgba}" group="{group}"/>
                        <geom name="col1" type="{type}" size="{width} {width} {height}" density="{density}"
                            rgba="{rgba}" group="{group}" pos="{x} {y} 0"/>
                        <geom name="col2" type="{type}" size="{width} {width} {height}" density="{density}"
                            rgba="{rgba}" group="{group}" pos="-{x} {y} 0"/>
                        <geom name="col3" type="{type}" size="{width} {width} {height}" density="{density}"
                            rgba="{rgba}" group="{group}" pos="{x} -{y} 0"/>
                        <geom name="col4" type="{type}" size="{width} {width} {height}" density="{density}"
                            rgba="{rgba}" group="{group}" pos="-{x} -{y} 0"/>
                    </body>
                '''.format(**{k: convert(v)
                              for k, v in object.items()}))
            else:
                body = xmltodict.parse(
                    '''
                    <body name="{name}" pos="{pos}" quat="{quat}">
                        <freejoint name="{name}"/>
                        <geom name="{name}" type="{type}" size="{size}" density="{density}"
                            rgba="{rgba}" group="{group}"/>
                    </body>
                '''.format(**{k: convert(v)
                              for k, v in object.items()}))
            # Append new body to world, making it a list optionally
            # Add the object to the world
            worldbody['body'].append(body['body'])
        # Add mocaps to the XML dictionary
        for name, mocap in self.mocaps.items():
            # Mocap names are suffixed with 'mocap'
            assert mocap['name'] == name, f'Inconsistent {name} {object}'
            assert name.replace(
                'mocap', 'obj') in self.objects, f'missing object for {name}'
            # Add the object to the world
            mocap = mocap.copy()  # don't modify original object
            mocap['quat'] = rot2quat(mocap['rot'])
            body = xmltodict.parse('''
                <body name="{name}" mocap="true">
                    <geom name="{name}" type="{type}" size="{size}" rgba="{rgba}"
                        pos="{pos}" quat="{quat}" contype="0" conaffinity="0" group="{group}"/>
                </body>
            '''.format(**{k: convert(v)
                          for k, v in mocap.items()}))
            worldbody['body'].append(body['body'])
            # Add weld to equality list
            mocap['body1'] = name
            mocap['body2'] = name.replace('mocap', 'obj')
            weld = xmltodict.parse('''
                <weld name="{name}" body1="{body1}" body2="{body2}" solref=".02 1.5"/>
            '''.format(**{k: convert(v)
                          for k, v in mocap.items()}))
            equality['weld'].append(weld['weld'])
        # Add geoms to XML dictionary
        for name, geom in self.geoms.items():
            assert geom['name'] == name, f'Inconsistent {name} {geom}'
            geom = geom.copy()  # don't modify original object
            geom['quat'] = rot2quat(geom['rot'])
            geom['contype'] = geom.get('contype', 1)
            geom['conaffinity'] = geom.get('conaffinity', 1)
            body = xmltodict.parse('''
                <body name="{name}" pos="{pos}" quat="{quat}">
                    <geom name="{name}" type="{type}" size="{size}" rgba="{rgba}" group="{group}"
                        contype="{contype}" conaffinity="{conaffinity}"/>
                </body>
            '''.format(**{k: convert(v)
                          for k, v in geom.items()}))
            # Append new body to world, making it a list optionally
            # Add the object to the world
            worldbody['body'].append(body['body'])

        # Instantiate simulator
        # print(xmltodict.unparse(self.xml, pretty=True))
        self.xml_string = xmltodict.unparse(self.xml)
        self.model = load_model_from_xml(self.xml_string)
        self.sim = MjSim(self.model)

        # Add render contexts to newly created sim
        if self.render_context is None and self.observe_vision:
            render_context = MjRenderContextOffscreen(
                self.sim, device_id=self.render_device_id, quiet=True)
            render_context.vopt.geomgroup[:] = 1
            self.render_context = render_context

        if self.render_context is not None:
            self.render_context.update_sim(self.sim)

        # Recompute simulation intrinsics from new position
        self.sim.forward()
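A hedged usage note: build() can be called repeatedly, e.g. after layout changes; the offscreen context is created once and re-attached to each fresh MjSim through mujoco-py's MjRenderContext.update_sim(). `world` is an assumed instance with observe_vision enabled:

world.build()
world.render_context.render(128, 128)
vision = world.render_context.read_pixels(128, 128, depth=False)[::-1]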
Code Example #22
    def __init__(self, args, env, env_params):
        self.args = args
        self.env = env
        self.env_params = env_params
        sim = self.env.sim
        self.viewer = MjRenderContextOffscreen(sim)
        # self.viewer.cam.fixedcamid = 3
        # self.viewer.cam.type = const.CAMERA_FIXED
        self.critic_loss = []
        self.actor_loss = []
        self.viewer.cam.distance = 1.2
        self.viewer.cam.azimuth = 180
        self.viewer.cam.elevation = -25
        env.env._viewers['rgb_array'] = self.viewer

        self.image_based = bool(args.image)
        print("Training image based RL ? : {}".format(self.image_based))
        # create the network
        if not self.image_based:
            self.actor_network = actor(env_params)
        else:
            self.actor_network = new_actor(env_params)
            #self.actor_network = resnet_actor(env_params)
        self.critic_network = critic(env_params)

        # sync the networks across the cpus
        sync_networks(self.actor_network)
        sync_networks(self.critic_network)
        # build up the target network
        if not self.image_based:
            self.actor_target_network = actor(env_params)
        else:
            #self.actor_target_network = resnet_actor(env_params)
            self.actor_target_network = new_actor(env_params)

        self.critic_target_network = critic(env_params)
        # load the weights into the target networks
        self.actor_target_network.load_state_dict(
            self.actor_network.state_dict())
        self.critic_target_network.load_state_dict(
            self.critic_network.state_dict())
        # if use gpu
        if self.args.cuda:
            print("use the GPU")
            self.actor_network.cuda(MPI.COMM_WORLD.Get_rank())
            self.critic_network.cuda(MPI.COMM_WORLD.Get_rank())
            self.actor_target_network.cuda(MPI.COMM_WORLD.Get_rank())
            self.critic_target_network.cuda(MPI.COMM_WORLD.Get_rank())

        # create the optimizer
        self.actor_optim = torch.optim.Adam(self.actor_network.parameters(),
                                            lr=self.args.lr_actor)
        self.critic_optim = torch.optim.Adam(self.critic_network.parameters(),
                                             lr=self.args.lr_critic)
        # her sampler
        self.her_module = her_sampler(self.args.replay_strategy,
                                      self.args.replay_k,
                                      self.env.compute_reward,
                                      self.image_based)
        # create the replay buffer
        self.buffer = replay_buffer(self.env_params, self.args.buffer_size,
                                    self.her_module.sample_her_transitions,
                                    self.image_based)
        # create the normalizer
        self.o_norm = normalizer(size=env_params['obs'],
                                 default_clip_range=self.args.clip_range)
        self.g_norm = normalizer(size=env_params['goal'],
                                 default_clip_range=self.args.clip_range)
        # create the directory for storing the model
        if MPI.COMM_WORLD.Get_rank() == 0:
            if not os.path.exists(self.args.save_dir):
                os.mkdir(self.args.save_dir)
            # path to save the model
            self.model_path = os.path.join(self.args.save_dir,
                                           self.args.env_name)
            if not os.path.exists(self.model_path):
                os.mkdir(self.model_path)