예제 #1
0
def make_model(no_walls=False):
    """Build the ant model XML, optionally stripping the wall geoms.

    Args:
      no_walls: When true, every geom named in `_WALLS` is removed from the
        parsed model before it is serialised.

    Returns:
      The pretty-printed model XML produced by `etree.tostring`.
    """
    raw_xml = common.read_model('ant.xml')
    blank_stripping_parser = etree.XMLParser(remove_blank_text=True)
    root = etree.XML(raw_xml, blank_stripping_parser)

    walls_to_drop = _WALLS if no_walls else ()
    for wall_name in walls_to_drop:
        geom = xml_tools.find_element(root, 'geom', wall_name)
        geom.getparent().remove(geom)

    return etree.tostring(root, pretty_print=True)
예제 #2
0
def make_model(n_boxes):
    """Returns a tuple containing the model XML string and a dict of assets.

    Bodies named ``box<n_boxes>`` up to ``box3`` are deleted from the stacker
    model, so only the first `n_boxes` boxes remain.

    Returns:
      A ``(xml_bytes, common.ASSETS)`` tuple.
    """
    raw_xml = common.read_model('stacker.xml')
    root = etree.XML(raw_xml, etree.XMLParser(remove_blank_text=True))

    # The fixed upper bound of 4 means the template is expected to define
    # box0..box3; everything from index n_boxes upwards is dropped.
    unused_names = ['box' + str(idx) for idx in range(n_boxes, 4)]
    for name in unused_names:
        box_body = xml_tools.find_element(root, 'body', name)
        box_body.getparent().remove(box_body)

    return etree.tostring(root, pretty_print=True), common.ASSETS
예제 #3
0
def get_model_and_assets(hand_length):
    """Returns a tuple containing the model XML string and a dict of assets.

    The reacher's arm and hand elements get new ``fromto`` spans derived from
    `hand_length`, and the elements two siblings after each are shifted to
    the new far end.
    """
    root = etree.fromstring(common.read_model('reacher.xml'))
    arm_offset = hand_length + 0.02

    # A trailing '/' in an ElementPath expression selects the first child.
    arm_elem = root.find('./worldbody/body/')
    arm_elem.set('fromto', '0 0 0 %g 0 0' % arm_offset)
    arm_elem.getnext().getnext().set('pos', "%g 0 0" % arm_offset)

    hand_elem = root.find('./worldbody/body/body/')
    hand_elem.set('fromto', "0 0 0 %g 0 0" % hand_length)
    hand_elem.getnext().getnext().set('pos', "%g 0 0" % arm_offset)

    return etree.tostring(root, pretty_print=True), common.ASSETS
예제 #4
0
def get_model_and_assets(length):
    """Returns a tuple containing the model XML string and a dict of assets.

    The pendulum pole's ``fromto`` and the following element's ``pos`` are set
    from `length`, and the fifth ``<worldbody>`` child is raised to
    ``length + 0.1``.
    """
    root = etree.fromstring(common.read_model('pendulum.xml'))

    def _advance(node, steps):
        # Walk `steps` siblings forward from `node`.
        for _ in range(steps):
            node = node.getnext()
        return node

    raised_body = _advance(root.find('./worldbody/'), 4)
    raised_body.set('pos', "0 0 %g" % (length + 0.1))

    # Third child of the first body is the pole; its next sibling is the mass.
    pole_elem = _advance(root.find('./worldbody/body/'), 2)
    pole_elem.set('fromto', "0 0 0 0 0 %g" % length)
    pole_elem.getnext().set('pos', "0 0 %g" % length)

    return etree.tostring(root, pretty_print=True), common.ASSETS
예제 #5
0
def _make_model(n_bodies):
    """Generates an xml string defining a swimmer with `n_bodies` bodies."""
    if n_bodies < 3:
        raise ValueError(
            'At least 3 bodies required. Received {}'.format(n_bodies))
    mjcf = etree.fromstring(common.read_model('swimmer.xml'))
    head_body = mjcf.find('./worldbody/body')
    actuator = etree.SubElement(mjcf, 'actuator')
    sensor = etree.SubElement(mjcf, 'sensor')

    parent = head_body
    for body_index in range(n_bodies - 1):
        site_name = 'site_{}'.format(body_index)
        child = _make_body(body_index=body_index)
        child.append(etree.Element('site', name=site_name))
        joint_name = 'joint_{}'.format(body_index)
        joint_limit = 360.0 / n_bodies
        joint_range = '{} {}'.format(-joint_limit, joint_limit)
        child.append(
            etree.Element('joint', {
                'name': joint_name,
                'range': joint_range
            }))
        #motor_name = 'motor_{}'.format(body_index)
        position_name = 'position_{}'.format(body_index)
        #actuator.append(etree.Element('motor', name=motor_name, joint=joint_name))
        actuator.append(
            etree.Element('position', name=position_name, joint=joint_name))
        velocimeter_name = 'velocimeter_{}'.format(body_index)
        sensor.append(
            etree.Element('velocimeter', name=velocimeter_name,
                          site=site_name))
        gyro_name = 'gyro_{}'.format(body_index)
        sensor.append(etree.Element('gyro', name=gyro_name, site=site_name))
        parent.append(child)
        parent = child

    # Move tracking cameras further away from the swimmer according to its length.
    cameras = mjcf.findall('./worldbody/body/camera')
    scale = n_bodies / 6.0
    for cam in cameras:
        if cam.get('mode') == 'trackcom':
            old_pos = cam.get('pos').split(' ')
            new_pos = ' '.join([str(float(dim) * scale) for dim in old_pos])
            cam.set('pos', new_pos)

    return etree.tostring(mjcf, pretty_print=True)
예제 #6
0
def make_model(floor_size=None,
               terrain=False,
               rangefinders=False,
               walls_and_ball=False,
               goal=False):
    """Returns the model XML string.

    Builds the quadruped model from ``quadruped.xml`` and strips every
    optional part that was not requested.

    Args:
      floor_size: If not None, the ``floor`` geom's ``size`` attribute becomes
        ``"<floor_size> <floor_size> .5"``.
      terrain: Keep the ``terrain`` geom when true; remove it otherwise.
      rangefinders: Keep all ``<rangefinder>`` sensors when true; remove them
        otherwise.
      walls_and_ball: Keep the geoms listed in `_WALLS`, the ``ball`` body and
        the ``target`` site when true; remove them otherwise.
      goal: Accepted for interface compatibility only; it currently has no
        effect on the generated model.

    Returns:
      The pretty-printed model XML produced by `etree.tostring`.
    """
    xml_string = common.read_model('quadruped.xml')
    parser = etree.XMLParser(remove_blank_text=True)
    mjcf = etree.XML(xml_string, parser)

    def _detach(element):
        # Remove `element` from its parent in place.
        element.getparent().remove(element)

    # Set floor size.
    if floor_size is not None:
        floor_geom = mjcf.find('.//geom[@name={!r}]'.format('floor'))
        floor_geom.attrib['size'] = '{} {} .5'.format(floor_size, floor_size)

    # Remove walls, ball and target.
    if not walls_and_ball:
        for wall in _WALLS:
            _detach(xml_tools.find_element(mjcf, 'geom', wall))
        _detach(xml_tools.find_element(mjcf, 'body', 'ball'))
        _detach(xml_tools.find_element(mjcf, 'site', 'target'))

    # Remove terrain.
    if not terrain:
        _detach(xml_tools.find_element(mjcf, 'geom', 'terrain'))

    # Remove rangefinders if they're not used, as range computations can be
    # expensive, especially in a scene with heightfields.
    if not rangefinders:
        for rf in mjcf.findall('.//rangefinder'):
            _detach(rf)

    return etree.tostring(mjcf, pretty_print=True)
예제 #7
0
def get_model_and_assets(body_length):
  """Returns a tuple containing the model XML string and a dict of assets.

  The seventh child of the first worldbody body gets ``fromto`` spanning
  ``[-body_length, body_length]`` on x. The element after it (the head) moves
  to ``x = body_length + 0.1``, and the next two elements are placed at
  ``x = -body_length`` and ``x = +body_length``.
  """
  root = etree.fromstring(common.read_model('cheetah.xml'))

  stretched = root.find('./worldbody/body/')
  for _ in range(6):
    stretched = stretched.getnext()
  stretched.set('fromto', "-%g 0 0 %g 0 0" % (body_length, body_length))

  head = stretched.getnext()
  head.set('pos', "%g 0 .1" % (body_length + 0.1))

  first_thigh = head.getnext()
  first_thigh.set('pos', "-%g 0 0" % body_length)
  second_thigh = first_thigh.getnext()
  second_thigh.set('pos', "%g 0 0" % body_length)

  return etree.tostring(root, pretty_print=True), common.ASSETS
예제 #8
0
def get_model_and_assets(length):
  """Returns a tuple containing the model XML string and a dict of assets.

  Rewrites the finger's proximal and distal link extents and the positions of
  the elements that hang off them so they track `length`.
  """
  root = etree.fromstring(common.read_model('finger.xml'))

  # Proximal link: third child of the first body; its next sibling is the
  # anchor of the distal part.
  proximal_elem = root.find('./worldbody/body/').getnext().getnext()
  proximal_elem.set('fromto', "0 0 0 0 0 -%g" % length)
  distal_anchor = proximal_elem.getnext()
  distal_anchor.set('pos', "0 0 -%g" % (length + 0.01))

  # Distal link: second and third children of the nested body, followed by
  # the two touch elements.
  distal_first = root.find('./worldbody/body/body/').getnext()
  distal_first.set('fromto', "0 0 0 0 0 -%g" % (length - 0.01))
  distal_second = distal_first.getnext()
  distal_second.set('fromto', "0 0 -.13 0 0 -%g" % (length - 0.01 + 0.001))
  touch_top = distal_second.getnext()
  touch_top.set('pos', ".01 0 -%g" % length)
  touch_top.getnext().set('pos', "-.01 0 -%g" % length)

  return etree.tostring(root, pretty_print=True), common.ASSETS
예제 #9
0
def get_model_and_assets(body_length):
    """Returns a tuple containing the model XML string and a dict of assets.

    The first worldbody body is raised to height ``1 + body_length``, and its
    first geom's ``size`` becomes ``"0.07 <body_length>"``. Leg lengths are
    left as defined in ``walker.xml``.
    """
    root = etree.fromstring(common.read_model('walker.xml'))
    root.find('./worldbody/body').set('pos', "0 0 %g" % (1 + body_length))
    root.find('./worldbody/body/geom').set('size', "0.07 %g" % body_length)
    return etree.tostring(root, pretty_print=True), common.ASSETS
예제 #10
0
def make_model(use_peg, insert):
    """Returns a tuple containing the model XML string and a dict of assets.

    Args:
      use_peg: Keep the ``peg``/``target_peg`` bodies when true, otherwise
        ``ball``/``target_ball``.
      insert: Additionally keep the receptacle: ``slot`` for the peg, ``cup``
        for the ball.

    Every other body named in `_ALL_PROPS` is removed from the model.
    """
    root = etree.XML(common.read_model('manipulator.xml'),
                     etree.XMLParser(remove_blank_text=True))

    if use_peg:
        prop, target, receptacle = 'peg', 'target_peg', 'slot'
    else:
        prop, target, receptacle = 'ball', 'target_ball', 'cup'
    keep = [prop, target] + ([receptacle] if insert else [])

    for unused_prop in _ALL_PROPS.difference(keep):
        body = xml_tools.find_element(root, 'body', unused_prop)
        body.getparent().remove(body)

    return etree.tostring(root, pretty_print=True), common.ASSETS
예제 #11
0
def _make_model(n_poles):
  """Generates an xml string defining a cart with `n_poles` bodies.

  With one pole the contents of ``cartpole.xml`` are returned unchanged.
  Otherwise poles ``pole_2`` .. ``pole_<n_poles>`` are nested under the first
  pole, the floor geom is lowered, and the first two cameras are moved back.
  """
  template = common.read_model('cartpole.xml')
  if n_poles == 1:
    return template

  root = etree.fromstring(template)

  # Nest each additional pole inside the previous one.
  tail = root.find('./worldbody/body/body')  # the first pole
  for idx in range(2, n_poles + 1):
    pole = etree.Element('body', name='pole_{}'.format(idx),
                         pos='0 0 1', childclass='pole')
    etree.SubElement(pole, 'joint', name='hinge_{}'.format(idx))
    etree.SubElement(pole, 'geom', name='pole_{}'.format(idx))
    tail.append(pole)
    tail = pole

  # Lower the floor as the chain grows.
  root.find('./worldbody/geom').set('pos', '0 0 {}'.format(1 - n_poles - .05))

  # Pull the cameras back in proportion to the number of poles.
  cams = root.findall('./worldbody/camera')
  cams[0].set('pos', '0 {} 1'.format(-1 - 2*n_poles))
  cams[1].set('pos', '0 {} 2'.format(-2*n_poles))
  return etree.tostring(root, pretty_print=True)
예제 #12
0
def make_maze_model(num_walls):
  """Return the quadruped model XML without ball, terrain or rangefinders.

  The floor size and target site are left as defined in ``quadruped.xml``.

  Args:
    num_walls: Currently unused.

  Returns:
    The pretty-printed model XML produced by `etree.tostring`.
  """
  root = etree.XML(common.read_model('quadruped.xml'),
                   etree.XMLParser(remove_blank_text=True))

  def _drop(element):
    element.getparent().remove(element)

  _drop(xml_tools.find_element(root, 'body', 'ball'))
  _drop(xml_tools.find_element(root, 'geom', 'terrain'))
  for rangefinder in root.findall('.//rangefinder'):
    _drop(rangefinder)

  return etree.tostring(root, pretty_print=True)
def get_model_and_assets_from_setting_kwargs(model_fname, setting_kwargs=None):
    """Returns a tuple containing the model XML string and a dict of assets.

    When `setting_kwargs` is given, the model, ``./common/materials.xml`` and
    ``./common/skybox.xml`` are round-tripped through `xmltodict` so that
    lighting, camera, distractor, floor-grid, body colour and skybox
    attributes can be overridden.

    Args:
      model_fname: Model file name passed to `common.read_model`.
      setting_kwargs: Optional dict of overrides. Recognised keys are
        ``light_pos``, ``cam_pos``, ``distractor_pos``, ``grid_rgb1``,
        ``grid_rgb2``, ``self_rgb``, ``skybox_rgb``, ``skybox_rgb2`` and
        ``skybox_markrgb`` (first 3 components used), ``grid_texrepeat``
        (first 2 components used) and ``grid_reflectance`` (stringified).

    Returns:
      ``(model_xml, assets)``, where `assets` maps each name in `_FILENAMES`
      to its resource contents, with materials/skybox replaced by the edited
      XML when overrides are applied.

    Raises:
      AssertionError: If a vector override is not a list, tuple or ndarray,
        or if the last worldbody body is not named ``distractor``.
      NotImplementedError: If a light/camera override is requested but no
        such element exists under ``<worldbody>`` or its first ``<body>``.
    """
    assets = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
              for filename in _FILENAMES}

    if setting_kwargs is None:
        return common.read_model(model_fname), assets

    def _vec(key, size=3):
        # Render the first `size` components of a vector override as the
        # space-separated string MJCF expects.
        value = setting_kwargs[key]
        assert isinstance(value, (list, tuple, np.ndarray))
        return ' '.join(format(value[i]) for i in range(size))

    def _first(node):
        # xmltodict yields a dict for a single child element and a list for
        # repeated ones; normalise to the first element either way.
        return node[0] if isinstance(node, list) else node

    def _set_entity_pos(tag, pos):
        # The entity may sit directly under <worldbody> or inside its first
        # top-level <body>.
        worldbody = model['mujoco']['worldbody']
        if tag in worldbody:
            _first(worldbody[tag])['@pos'] = pos
            return
        body = _first(worldbody.get('body') or {})
        if tag in body:
            _first(body[tag])['@pos'] = pos
            return
        raise NotImplementedError(f'model xml does not contain entity {tag}')

    # Convert XML to dicts
    model = xmltodict.parse(common.read_model(model_fname))
    materials = xmltodict.parse(assets['./common/materials.xml'])
    skybox = xmltodict.parse(assets['./common/skybox.xml'])

    # Edit lighting
    if 'light_pos' in setting_kwargs:
        _set_entity_pos('light', _vec('light_pos'))

    # Edit camera
    if 'cam_pos' in setting_kwargs:
        _set_entity_pos('camera', _vec('cam_pos'))

    # Edit distractor (expected to be the last top-level body)
    if 'distractor_pos' in setting_kwargs:
        distractor_pos = _vec('distractor_pos')
        bodies = model['mujoco']['worldbody']['body']
        distractor = bodies[-1] if isinstance(bodies, list) else bodies
        assert distractor['@name'] == 'distractor', 'distractor must be in worldbody'
        distractor['geom']['@pos'] = distractor_pos

    # Edit grid floor
    if 'grid_rgb1' in setting_kwargs:
        materials['mujoco']['asset']['texture']['@rgb1'] = _vec('grid_rgb1')
    if 'grid_rgb2' in setting_kwargs:
        materials['mujoco']['asset']['texture']['@rgb2'] = _vec('grid_rgb2')
    if 'grid_texrepeat' in setting_kwargs:
        materials['mujoco']['asset']['material'][0]['@texrepeat'] = \
            _vec('grid_texrepeat', size=2)
    if 'grid_reflectance' in setting_kwargs:
        materials['mujoco']['asset']['material'][0]['@reflectance'] = \
            str(setting_kwargs["grid_reflectance"])

    # Edit self (alpha is always forced to 1)
    if 'self_rgb' in setting_kwargs:
        materials['mujoco']['asset']['material'][1]['@rgba'] = \
            _vec('self_rgb') + ' 1'

    # Edit skybox
    if 'skybox_rgb' in setting_kwargs:
        skybox['mujoco']['asset']['texture']['@rgb1'] = _vec('skybox_rgb')
    if 'skybox_rgb2' in setting_kwargs:
        skybox['mujoco']['asset']['texture']['@rgb2'] = _vec('skybox_rgb2')
    if 'skybox_markrgb' in setting_kwargs:
        skybox['mujoco']['asset']['texture']['@markrgb'] = _vec('skybox_markrgb')

    # Convert back to XML
    model_xml = xmltodict.unparse(model)
    assets['./common/materials.xml'] = xmltodict.unparse(materials)
    assets['./common/skybox.xml'] = xmltodict.unparse(skybox)

    return model_xml, assets
예제 #14
0
def _make_model(n_poles):
  """Generates an xml string defining a cart with `n_poles` bodies."""
  xml_string = common.read_model('cartpole_2.xml')
  if n_poles == 1:
    return xml_string
def get_model_and_assets():
    """Load ``ball_in_cup_distractor.xml`` together with the shared assets.

    Returns:
      ``(model_xml, common.ASSETS)``.
    """
    model_xml = common.read_model('ball_in_cup_distractor.xml')
    return model_xml, common.ASSETS
예제 #16
0
def main(xml_name):
    """Train a PPO agent on `FindTarget` and periodically record videos.

    Args:
      xml_name: Path (relative to the current working directory) of the model
        XML. It is used both to configure a `DummyGen` and to build the
        physics.

    Side effects:
      Periodically saves the agent to ``./models/starfish_model_target``.
      Every `VID_EVERY` episodes and after the last one, renders a rollout to
      ``/tmp/starfish_NNNN.jpg`` frames, encodes them with ``ffmpeg`` into
      ``videos/starfish_<episode|final>.mp4``, and deletes the frames.
    """
    with open(xml_name) as xmlf:
        xml_str = xmlf.read()

    gen = DummyGen()
    gen.override_from_xml(xml_str)

    _DEFAULT_TIME_LIMIT = 10
    _CONTROL_TIMESTEP = .04

    genesis_physics = Physics.from_xml_string(
        common.read_model(os.path.join(os.getcwd(), xml_name)),
        common.ASSETS)

    genesis_physics.set_genesis(gen)
    genesis_task = FindTarget()
    genesis_env = control.Environment(genesis_physics,
                                      genesis_task,
                                      control_timestep=_CONTROL_TIMESTEP,
                                      time_limit=_DEFAULT_TIME_LIMIT)
    action_spec = genesis_env.action_spec()
    observation_spec = genesis_env.observation_spec()

    # Flattened observation length: scalar entries count as 1, array entries
    # contribute their leading dimension.
    obs_size = 0
    for (name, row) in observation_spec.items():
        print(name, obs_size, row.shape)
        if row.shape == ():
            obs_size += 1
            continue
        print(row.shape)
        obs_size += row.shape[0]
    observation_shape = (obs_size,)
    print(action_spec)
    print(action_spec.minimum)
    # NOTE(review): the state bounds below reuse the *action* spec limits,
    # whose shape differs from observation_shape; confirm this is intended.
    agent = PPOAgent(
        states=dict(type='float', min_value=action_spec.minimum, max_value=action_spec.maximum, shape=observation_shape),
        actions=dict(type='float', min_value=action_spec.minimum, max_value=action_spec.maximum, shape=action_spec.shape),
        network=[
            dict(type='dense', size=128, activation='relu'),
            dict(type='dense', size=64, activation='relu'),
            dict(type='dense', size=16, activation='tanh')
        ],
        step_optimizer={
            "type": "adam",
            "learning_rate": 1e-4
        },
        entropy_regularization=0.01,
        batching_capacity=64,
        subsampling_fraction=0.1,
        optimization_steps=50,
        discount=0.99,
        likelihood_ratio_clipping=0.2,
        baseline_mode="states",
        baseline={
            "type": "mlp",
            "sizes": [32, 32]
        },
        baseline_optimizer={
            "type": "multi_step",
            "optimizer": {
                "type": "adam",
                "learning_rate": 1e-4
            },
            "num_steps": 5
        },
        update_mode={
            "unit": "episodes",
            "batch_size": 128,
            "frequency": 10
        },
        memory={
            "type": "latest",
            "include_next_states": False,
            "capacity": 2000
        }
    )

    # Initial reset and renders kept from the original flow; presumably they
    # fail fast if either camera is missing.
    genesis_env.reset()
    genesis_env.physics.render(480, 480, camera_id='tracking_top')
    genesis_env.physics.render(480, 480, camera_id='arm_eye')

    NUM_EPISODES = 10000
    N_INPROG_VIDS = 4
    VID_EVERY = NUM_EPISODES // N_INPROG_VIDS

    # ffmpeg runs with logging disabled, so a missing output directory would
    # fail silently; saving the model may also require its parent directory.
    os.makedirs('videos', exist_ok=True)
    os.makedirs('./models', exist_ok=True)

    for i in tqdm.tqdm(range(NUM_EPISODES)):
        time_step = genesis_env.reset()
        tot = 0
        while not time_step.last():
            state = observation2state(time_step.observation)
            action = agent.act(state)
            time_step = genesis_env.step(action)
            tot += time_step.reward
            agent.observe(reward=time_step.reward, terminal=time_step.last())

        if i % 100 == 0:
            tqdm.tqdm.write("for episode " + str(i) + " : " + str(tot))

        if (i % VID_EVERY) == 0 or i == NUM_EPISODES - 1:
            agent.save_model('./models/starfish_model_target')

            time_step = genesis_env.reset()

            vid_suffix = 'final' if i == NUM_EPISODES - 1 else str(i)
            vid_name = 'videos/starfish_{}.mp4'.format(vid_suffix)

            imnames = set()
            picidx = 0
            # Reset per recording: a physics failure in one rollout must not
            # suppress every later video.
            did_except = False

            # NOTE(review): act() is called here without a matching observe();
            # depending on the agent library version an independent /
            # deterministic act may be required for evaluation rollouts.
            while not time_step.last():
                try:
                    state = observation2state(time_step.observation)
                    action = agent.act(state)
                    time_step = genesis_env.step(action)
                    savename = "/tmp/starfish_{0:04}.jpg".format(picidx)
                    top_view = genesis_env.physics.render(480, 480, camera_id='tracking_top')
                    side_view = genesis_env.physics.render(480, 480, camera_id='arm_eye')
                    io.imsave(savename, np.concatenate((top_view, side_view), axis=1))
                    # Record the frame only once it exists on disk, so the
                    # cleanup below never removes a file that was not written.
                    imnames.add(savename)
                    picidx += 1
                except PhysicsError:
                    print('except')
                    did_except = True
                    break
            if os.path.isfile(vid_name):
                os.remove(vid_name)
            if not did_except:
                os.system('ffmpeg -nostats -loglevel 0 -f image2 -pattern_type sequence -i "/tmp/starfish_%4d.jpg" -qscale:v 0 {}'.format(vid_name))
            for name in imnames:
                os.remove(name)
            print("recorded video")
예제 #17
0
def _get_model_and_assets(model_filename):
    """Read the model at `mujoco_model_path(model_filename)` with its assets.

    Returns:
      ``(model_xml, common.ASSETS)``.
    """
    resolved_path = mujoco_model_path(model_filename)
    return common.read_model(resolved_path), common.ASSETS
예제 #18
0
def get_model_and_assets():
    """Load ``multislider_render.xml`` together with the shared assets.

    Returns:
      ``(model_xml, common.ASSETS)``.
    """
    model_xml = common.read_model('multislider_render.xml')
    return model_xml, common.ASSETS
예제 #19
0
def get_model_and_assets():
    """Load ``dust.xml`` from the current working directory with the assets.

    Returns:
      ``(model_xml, common.ASSETS)``.
    """
    # os.path.join instead of string concatenation: portable separators and
    # no doubled '/' when the working directory is the filesystem root.
    model_path = os.path.join(os.getcwd(), 'dust.xml')
    return common.read_model(model_path), common.ASSETS
예제 #20
0
def _make_model():
  """Return ``jaco_pos.xml`` (next to this module) re-serialised by lxml.

  The file is parsed and pretty-printed with `etree.tostring`.
  """
  module_dir = os.path.dirname( __file__ )
  root = etree.fromstring(
      common.read_model(os.path.join(module_dir, 'jaco_pos.xml')))
  return etree.tostring(root, pretty_print=True)
예제 #21
0
def get_model_and_assets():
    """Load ``kendama_catch_simulation.xml`` together with the shared assets.

    Returns:
      ``(model_xml, common.ASSETS)``.
    """
    model_xml = common.read_model('kendama_catch_simulation.xml')
    return model_xml, common.ASSETS
예제 #22
0
def get_model_and_assets():
    """Load ``cheetah.xml`` together with the shared assets.

    Returns:
      ``(model_xml, common.ASSETS)``.
    """
    model_xml = common.read_model("cheetah.xml")
    return model_xml, common.ASSETS
예제 #23
0
def get_model_and_assets_from_setting_kwargs(model_fname, setting_kwargs=None):
    """Returns a tuple containing the model XML string and a dict of assets.

    When `setting_kwargs` is given, ``./common/materials.xml`` and
    ``./common/skybox.xml`` are round-tripped through `xmltodict` so that the
    floor grid, body colour and skybox can be overridden. The model itself is
    parsed and re-serialised unchanged.

    Args:
      model_fname: Model file name passed to `common.read_model`.
      setting_kwargs: Optional dict of overrides. Recognised keys are
        ``grid_rgb1``, ``grid_rgb2``, ``grid_markrgb``, ``self_rgb``,
        ``skybox_rgb``, ``skybox_rgb2`` and ``skybox_markrgb`` (first 3
        components used) and ``grid_texrepeat`` (first 2 components used).

    Returns:
      ``(model_xml, assets)``, where `assets` maps each name in `_FILENAMES`
      to its resource contents, with materials/skybox replaced by the edited
      XML when overrides are applied.

    Raises:
      AssertionError: If an override is not a list, tuple or ndarray, or if
        the texture/material being edited does not carry the expected name.
    """
    assets = {
        filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
        for filename in _FILENAMES
    }

    if setting_kwargs is None:
        return common.read_model(model_fname), assets

    def _vec(key, size=3):
        # Render the first `size` components of a vector override as the
        # space-separated string MJCF expects.
        value = setting_kwargs[key]
        assert isinstance(value, (list, tuple, np.ndarray))
        return ' '.join(format(value[i]) for i in range(size))

    def _set_texture_attr(tree, expected_name, key, attr):
        # Format first (type check), then verify the texture's identity,
        # then write the attribute.
        text = _vec(key)
        texture = tree['mujoco']['asset']['texture']
        assert texture['@name'] == expected_name
        texture[attr] = text

    # Convert XML to dicts
    model = xmltodict.parse(common.read_model(model_fname))
    materials = xmltodict.parse(assets['./common/materials.xml'])
    skybox = xmltodict.parse(assets['./common/skybox.xml'])

    # Edit grid floor
    for key, attr in (('grid_rgb1', '@rgb1'),
                      ('grid_rgb2', '@rgb2'),
                      ('grid_markrgb', '@markrgb')):
        if key in setting_kwargs:
            _set_texture_attr(materials, 'grid', key, attr)
    if 'grid_texrepeat' in setting_kwargs:
        texrepeat = _vec('grid_texrepeat', size=2)
        assert materials['mujoco']['asset']['texture']['@name'] == 'grid'
        materials['mujoco']['asset']['material'][0]['@texrepeat'] = texrepeat

    # Edit self (alpha is always forced to 1)
    if 'self_rgb' in setting_kwargs:
        rgba = _vec('self_rgb') + ' 1'
        assert materials['mujoco']['asset']['material'][1]['@name'] == 'self'
        materials['mujoco']['asset']['material'][1]['@rgba'] = rgba

    # Edit skybox
    for key, attr in (('skybox_rgb', '@rgb1'),
                      ('skybox_rgb2', '@rgb2'),
                      ('skybox_markrgb', '@markrgb')):
        if key in setting_kwargs:
            _set_texture_attr(skybox, 'skybox', key, attr)

    # Convert back to XML
    model_xml = xmltodict.unparse(model)
    assets['./common/materials.xml'] = xmltodict.unparse(materials)
    assets['./common/skybox.xml'] = xmltodict.unparse(skybox)

    return model_xml, assets
예제 #24
0
def get_model_and_assets():
    """Load the model named by the ``model_filename`` flag with the assets.

    Returns:
      ``(model_xml, common.ASSETS)``.
    """
    filename = FLAGS.model_filename
    return common.read_model(filename), common.ASSETS
예제 #25
0
def get_model_and_assets():
    """Load ``cloth_v4.xml`` together with the shared assets.

    Returns:
      ``(model_xml, common.ASSETS)``.
    """
    model_xml = common.read_model('cloth_v4.xml')
    return model_xml, common.ASSETS
예제 #26
0
def get_model_and_assets():
  """Load ``humanoid_CMU.xml`` together with the shared assets.

  Returns:
    ``(model_xml, common.ASSETS)``.
  """
  model_xml = common.read_model('humanoid_CMU.xml')
  return model_xml, common.ASSETS