Пример #1
0
  def __init__(
      self, arena, arm, hand, prop, obs_settings, workspace, control_timestep):
    """Sets up the robot, target and observables for a `Reach` task.

    Args:
      arena: `composer.Entity` instance.
      arm: `robot_base.RobotArm` instance.
      hand: `robot_base.RobotHand` instance.
      prop: `composer.Entity` instance to reach towards. If falsy, a visible
        target site is created in the arena instead and its position is
        exposed as the `target_position` observable.
      obs_settings: `observations.ObservationSettings` instance.
      workspace: `_ReachWorkspace` providing `arm_offset`, `tcp_bbox` and
        `target_bbox`.
      control_timestep: Float specifying the control timestep in seconds.
    """
    self._arena = arena
    self._arm = arm
    self._hand = hand

    # Mount the hand on the arm, then the arm on the arena.
    self._arm.attach(self._hand)
    self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
    self.control_timestep = control_timestep

    tcp_distribution = distributions.Uniform(*workspace.tcp_bbox)
    self._tcp_initializer = initializers.ToolCenterPointInitializer(
        self._hand, self._arm,
        position=tcp_distribution,
        quaternion=workspaces.DOWN_QUATERNION)

    # Camera observables form the initial set of task observables.
    self._task_observables = cameras.add_camera_observables(
        arena, obs_settings, cameras.FRONT_CLOSE)

    target_distribution = distributions.Uniform(*workspace.target_bbox)
    self._prop = prop
    if not prop:
      # No prop: the target is a visible site whose position is observable.
      self._target = self._make_target_site(parent_entity=arena, visible=True)
      self._target_placer = target_distribution

      target_obs = observable.MJCFFeature('pos', self._target)
      target_obs.configure(**obs_settings.prop_pose._asdict())
      self._task_observables['target_position'] = target_obs
    else:
      # The prop itself is used to visualize the target location, so the
      # target site attached to it is kept invisible.
      self._make_target_site(parent_entity=prop, visible=False)
      self._target = self._arena.add_free_entity(prop)
      self._prop_placer = initializers.PropPlacer(
          props=[prop],
          position=target_distribution,
          quaternion=workspaces.uniform_z_rotation,
          settle_physics=True)

    # Sites visualizing the TCP and target spawn bounding boxes.
    worldbody = self.root_entity.mjcf_model.worldbody
    spawn_areas = (
        (workspace.tcp_bbox, constants.GREEN, 'tcp_spawn_area'),
        (workspace.target_bbox, constants.BLUE, 'target_spawn_area'),
    )
    for bbox, rgba, site_name in spawn_areas:
      workspaces.add_bbox_site(
          body=worldbody,
          lower=bbox.lower, upper=bbox.upper,
          rgba=rgba, name=site_name)
Пример #2
0
    def __init__(self, arena, arm, hand, prop, obs_settings, workspace,
                 control_timestep):
        """Sets up the robot, prop and observables for a `Lift` task.

        Args:
          arena: `composer.Entity` instance.
          arm: `robot_base.RobotArm` instance.
          hand: `robot_base.RobotHand` instance.
          prop: `composer.Entity` instance.
          obs_settings: `observations.ObservationSettings` instance.
          workspace: `_LiftWorkspace` providing `arm_offset`, `tcp_bbox` and
            `prop_bbox`.
          control_timestep: Float specifying the control timestep in seconds.
        """
        self._arena = arena
        self._arm = arm
        self._hand = hand

        # Mount the hand on the arm, then the arm on the arena.
        self._arm.attach(self._hand)
        self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
        self.control_timestep = control_timestep

        # Camera observables form the initial set of task observables.
        self._task_observables = cameras.add_camera_observables(
            arena, obs_settings, cameras.FRONT_CLOSE)

        tcp_distribution = distributions.Uniform(*workspace.tcp_bbox)
        self._tcp_initializer = initializers.ToolCenterPointInitializer(
            self._hand,
            self._arm,
            position=tcp_distribution,
            quaternion=workspaces.DOWN_QUATERNION)

        self._prop = prop
        self._arena.add_free_entity(prop)
        prop_distribution = distributions.Uniform(*workspace.prop_bbox)
        self._prop_placer = initializers.PropPlacer(
            props=[prop],
            position=prop_distribution,
            quaternion=workspaces.uniform_z_rotation,
            ignore_collisions=True,
            settle_physics=True)

        # Sites visualizing the target height and the spawn bounding boxes.
        worldbody = self.root_entity.mjcf_model.worldbody
        self._target_height_site = workspaces.add_bbox_site(
            body=worldbody,
            lower=(-1, -1, 0),
            upper=(1, 1, 0),
            rgba=constants.RED,
            name='target_height')
        spawn_areas = (
            (workspace.tcp_bbox, constants.GREEN, 'tcp_spawn_area'),
            (workspace.prop_bbox, constants.BLUE, 'prop_spawn_area'),
        )
        for bbox, rgba, site_name in spawn_areas:
            workspaces.add_bbox_site(body=worldbody,
                                     lower=bbox.lower,
                                     upper=bbox.upper,
                                     rgba=rgba,
                                     name=site_name)
    def __init__(self, observation_settings, opponent, game_logic, board,
                 markers):
        """Assembles the arena, Jaco arm and hand, board and markers.

        Args:
          observation_settings: An `observations.ObservationSettings`
            namedtuple specifying configuration options for each category of
            observation.
          opponent: Opponent used for generating opponent moves.
          game_logic: Logic for keeping track of the logical state of the
            board.
          board: Board to use.
          markers: Markers to use.
        """

        def _options_for(observable_names):
            # Observable options for one category of observables.
            return observations.make_options(observation_settings,
                                             observable_names)

        def _read_board_state(physics):
            # `physics` is not used; the state comes from the game logic.
            return self._game_logic.get_board_state()

        self._game_logic = game_logic
        self._game_opponent = opponent

        scene = arenas.Standard(
            observable_options=_options_for(observations.ARENA_OBSERVABLES))
        scene.attach(board)
        jaco_arm = kinova.JacoArm(
            observable_options=_options_for(observations.JACO_ARM_OBSERVABLES))
        jaco_hand = kinova.JacoHand(
            observable_options=_options_for(
                observations.JACO_HAND_OBSERVABLES))
        jaco_arm.attach(jaco_hand)
        scene.attach_offset(jaco_arm, offset=(0, _ARM_Y_OFFSET, 0))
        scene.attach(markers)

        # Geoms belonging to the arm and hand are moved to a dedicated group
        # so that they can be hidden from the top-down camera. NB: this
        # assumes no non-robot geoms live in ROBOT_GEOM_GROUP (usually true,
        # since the default geom group is 0); any that do would be hidden
        # from the top-down camera as well.
        for geom in jaco_arm.mjcf_model.find_all('geom'):
            geom.group = arenas.ROBOT_GEOM_GROUP

        self._arena = scene
        self._board = board
        self._arm = jaco_arm
        self._hand = jaco_hand
        self._markers = markers

        tcp_distribution = distributions.Uniform(_TCP_LOWER_BOUNDS,
                                                 _TCP_UPPER_BOUNDS)
        self._tcp_initializer = initializers.ToolCenterPointInitializer(
            hand=jaco_hand,
            arm=jaco_arm,
            position=tcp_distribution,
            quaternion=_uniform_downward_rotation())

        # Observable exposing the logical state of the board.
        board_state = observable.Generic(_read_board_state)
        board_state.configure(**observation_settings.board_state._asdict())
        self._task_observables = {'board_state': board_state}
Пример #4
0
    def __init__(self, arena, arm, hand, num_bricks, obs_settings, workspace,
                 control_timestep):
        """Sets up the robot, bricks and goal-hint bricks for the task.

        Args:
          arena: Arena entity that the arm, bricks and hint bricks are added
            to.
          arm: Robot arm; the hand is attached to it.
          hand: Robot hand.
          num_bricks: Number of bricks to create, between 2 and 6 inclusive.
          obs_settings: Observation settings used for the camera and brick
            observables.
          workspace: Workspace providing `arm_offset`, `tcp_bbox`,
            `prop_bbox`, `goal_hint_pos` and `goal_hint_quat`.
          control_timestep: Control timestep assigned to the task.

        Raises:
          ValueError: If `num_bricks` is not between 2 and 6.
        """
        if not 2 <= num_bricks <= 6:
            message = '`num_bricks` must be between 2 and 6, got {}.'.format(
                num_bricks)
            raise ValueError(message)

        if num_bricks > 3:
            # The default values computed by MuJoCo's compiler are too small
            # when more than three bricks are stacked, since each stacked
            # pair generates a large number of contacts. The values below are
            # sufficient for up to 6 stacked bricks.
            # TODO(b/78331644): It may be useful to log the size of
            #                   `physics.model` and `physics.data` after
            #                   compilation to gauge the impact of these
            #                   changes on MuJoCo's memory footprint.
            model_size = arena.mjcf_model.size
            model_size.nconmax = 400
            model_size.njmax = 1200

        self._arena = arena
        self._arm = arm
        self._hand = hand

        # Mount the hand on the arm, then the arm on the arena.
        self._arm.attach(self._hand)
        self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
        self.control_timestep = control_timestep

        # Camera observables form the initial set of task observables.
        self._task_observables = cameras.add_camera_observables(
            arena, obs_settings, cameras.FRONT_CLOSE)

        palette = iter(_COLOR_VALUES)
        real_brick_options = observations.make_options(
            obs_settings, observations.FREEPROP_OBSERVABLES)

        self._bricks = []
        self._brick_frames = []
        self._goal_hint_bricks = []
        for _ in range(num_bricks):
            brick_color = next(palette)

            # The real brick: a free entity with observables.
            real_brick = props.Duplo(color=brick_color,
                                     observable_options=real_brick_options)
            self._brick_frames.append(arena.add_free_entity(real_brick))
            self._bricks.append(real_brick)

            # Translucent, contactless brick with no observables, used as a
            # visual hint representing the goal state for each task.
            ghost_brick = props.Duplo(color=brick_color)
            _hintify(ghost_brick, alpha=_HINT_ALPHA)
            arena.attach(ghost_brick)
            self._goal_hint_bricks.append(ghost_brick)

        # Position and quaternion for the goal hint.
        self._goal_hint_pos = workspace.goal_hint_pos
        self._goal_hint_quat = workspace.goal_hint_quat

        tcp_distribution = distributions.Uniform(*workspace.tcp_bbox)
        self._tcp_initializer = initializers.ToolCenterPointInitializer(
            self._hand,
            self._arm,
            position=tcp_distribution,
            quaternion=workspaces.DOWN_QUATERNION)

        # Sites for visual debugging of the spawn bounding boxes.
        worldbody = self.root_entity.mjcf_model.worldbody
        spawn_areas = (
            (workspace.tcp_bbox, constants.GREEN, 'tcp_spawn_area'),
            (workspace.prop_bbox, constants.BLUE, 'prop_spawn_area'),
        )
        for bbox, rgba, site_name in spawn_areas:
            workspaces.add_bbox_site(body=worldbody,
                                     lower=bbox.lower,
                                     upper=bbox.upper,
                                     rgba=rgba,
                                     name=site_name)
Пример #5
0
    def __init__(self, arena, arm, hand, prop, obs_settings, workspace,
                 control_timestep, table_col_tag):
        """Initializes a new `Reach` task with a selectable table surface.

        Args:
          arena: `composer.Entity` instance.
          arm: `robot_base.RobotArm` instance.
          hand: `robot_base.RobotHand` instance.
          prop: `composer.Entity` instance specifying the prop to reach to,
            or None in which case the target is a fixed site whose position
            is specified by the workspace.
          obs_settings: `observations.ObservationSettings` instance.
          workspace: `_ReachWorkspace` specifying the placement of the prop
            and TCP.
          control_timestep: Float specifying the control timestep in seconds.
          table_col_tag: Tag selecting the material of the table-surface
            plane added to the world body: 0 or 1 selects 'j2s7/robot_bw',
            2 selects 'j2s7/real_desk'. For any other value no plane is
            added.
        """
        # NOTE(review): tags 0 and 1 select the same material; confirm
        # whether one of them was meant to be different.
        table_materials = {
            0: "j2s7/robot_bw",
            1: "j2s7/robot_bw",
            2: "j2s7/real_desk",
        }

        self._arena = arena
        self._arm = arm
        self._hand = hand
        self._arm.attach(self._hand)
        self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
        self.control_timestep = control_timestep
        self._tcp_initializer = initializers.ToolCenterPointInitializer(
            self._hand,
            self._arm,
            position=distributions.Uniform(*workspace.tcp_bbox),
            quaternion=workspaces.DOWN_QUATERNION)

        # Add custom camera observable.
        self._task_observables = cameras.add_camera_observables(
            arena, obs_settings, cameras.FRONT_FAR)

        target_pos_distribution = distributions.Uniform(*workspace.target_bbox)
        self._prop = prop
        if prop:
            # The prop itself is used to visualize the target location.
            self._make_target_site(parent_entity=prop, visible=False)
            self._target = self._arena.add_free_entity(prop)
            self._prop_placer = initializers.PropPlacer(
                props=[prop],
                position=target_pos_distribution,
                quaternion=workspaces.uniform_z_rotation,
                settle_physics=True)
        else:
            self._target = self._make_target_site(parent_entity=arena,
                                                  visible=True)
            self._target_placer = target_pos_distribution

            obs = observable.MJCFFeature('pos', self._target)
            obs.configure(**obs_settings.prop_pose._asdict())
            self._task_observables['target_position'] = obs

        mjcf_root = self.root_entity.mjcf_model

        # Table surface: a single plane whose material depends on the tag.
        table_material = table_materials.get(table_col_tag)
        if table_material is not None:
            mjcf_root.worldbody.add('geom',
                                    type='plane',
                                    pos="0 0 0.01",
                                    size="0.6 0.6 0.5",
                                    rgba=".6 .6 .5 1",
                                    contype="1",
                                    conaffinity="1",
                                    friction="2 0.1 0.002",
                                    material=table_material)

        # Blue starry sky.
        sky_texture = mjcf_root.asset.texture[0]
        sky_texture.mark = 'random'
        sky_texture.rgb1 = np.array([.4, .6, .8])
        sky_texture.width = 800
        sky_texture.height = 800

        # TODO: find a proper way to remove the checkerboard groundplane.
        # For now, shrinking its texture to 1x1 makes it render as a white
        # plane.
        ground_texture = mjcf_root.asset.texture[1]
        ground_texture.width = 1
        ground_texture.height = 1

        # Add sites for visualizing the prop and target bounding boxes.
        workspaces.add_bbox_site(body=mjcf_root.worldbody,
                                 lower=workspace.tcp_bbox.lower,
                                 upper=workspace.tcp_bbox.upper,
                                 rgba=constants.GREEN,
                                 name='tcp_spawn_area')
        workspaces.add_bbox_site(body=mjcf_root.worldbody,
                                 lower=workspace.target_bbox.lower,
                                 upper=workspace.target_bbox.upper,
                                 rgba=constants.BLUE,
                                 name='target_spawn_area')
Пример #6
0
    def __init__(self, arena, arm, hand, prop, obs_settings, workspace,
                 control_timestep, table_col_tag, sky_col_tag):
        """Initializes a new `Reach` task with selectable table and sky.

        Args:
          arena: `composer.Entity` instance.
          arm: `robot_base.RobotArm` instance.
          hand: `robot_base.RobotHand` instance.
          prop: `composer.Entity` instance specifying the prop to reach to,
            or None in which case the target is a fixed site whose position
            is specified by the workspace.
          obs_settings: `observations.ObservationSettings` instance.
          workspace: `_ReachWorkspace` specifying the placement of the prop
            and TCP.
          control_timestep: Float specifying the control timestep in seconds.
          table_col_tag: Tag in 0..11 selecting the material of the
            table-surface plane added to the world body. For any other value
            no plane is added.
          sky_col_tag: Tag selecting how the first texture asset (the sky)
            is recolored: 1 red, 2 orange, 3 yellow, 4 pink, 5 amber (all
            with random marks), 6 white, 7 black. For any other value the
            sky texture is left unchanged.
        """
        # Material of the table plane for each supported `table_col_tag`.
        table_materials = {
            0: "j2s7/darkwood",
            1: "j2s7/marble",
            2: "j2s7/navy_blue",
            3: "j2s7/tennis",
            4: "j2s7/wood",
            5: "j2s7/wood_light",
            6: "j2s7/light_wood_v2",
            7: "j2s7/metal",
            8: "j2s7/grass",
            9: "j2s7/blue_cloud",
            10: "j2s7/marble_v2",
            11: "j2s7/wood_gray",
        }
        # `rgb1` of the randomly-marked ("starry") sky for each tag.
        starry_sky_rgb1 = {
            1: [.8, .1, .4],  # Red.
            2: [.8, .5, .1],  # Orange.
            3: [1, 1, .4],  # Yellow.
            4: [1, .5, 1],  # Pink.
            5: [1, .6, .4],  # Amber.
        }

        self._arena = arena
        self._arm = arm
        self._hand = hand
        self._arm.attach(self._hand)
        self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
        self.control_timestep = control_timestep
        self._tcp_initializer = initializers.ToolCenterPointInitializer(
            self._hand,
            self._arm,
            position=distributions.Uniform(*workspace.tcp_bbox),
            quaternion=workspaces.DOWN_QUATERNION)

        # Add custom camera observable.
        # Experimental: randomize the camera viewing angle during training.
        # Disabled by default, in which case the far front camera is used.
        randomize_cam = False
        if randomize_cam:
            camera = random.choice([
                cameras.FRONT_FAR,
                cameras.FRONT_CLOSE,
                cameras.FRONT_CLOSE_TILT_UP,
                cameras.FRONT_CLOSE_TILT_DOWN,
            ])
        else:
            camera = cameras.FRONT_FAR
        self._task_observables = cameras.add_camera_observables(
            arena, obs_settings, camera)

        target_pos_distribution = distributions.Uniform(*workspace.target_bbox)
        self._prop = prop
        if prop:
            # The prop itself is used to visualize the target location.
            self._make_target_site(parent_entity=prop, visible=False)
            self._target = self._arena.add_free_entity(prop)
            self._prop_placer = initializers.PropPlacer(
                props=[prop],
                position=target_pos_distribution,
                quaternion=workspaces.uniform_z_rotation,
                settle_physics=True)
        else:
            self._target = self._make_target_site(parent_entity=arena,
                                                  visible=True)
            self._target_placer = target_pos_distribution

            obs = observable.MJCFFeature('pos', self._target)
            obs.configure(**obs_settings.prop_pose._asdict())
            self._task_observables['target_position'] = obs

        mjcf_root = self.root_entity.mjcf_model

        # Table surface: a single plane whose material depends on the tag.
        table_material = table_materials.get(table_col_tag)
        if table_material is not None:
            mjcf_root.worldbody.add('geom',
                                    type='plane',
                                    pos="0 0 0.01",
                                    size="0.6 0.6 0.5",
                                    rgba=".6 .6 .5 1",
                                    contype="1",
                                    conaffinity="1",
                                    friction="2 0.1 0.002",
                                    material=table_material)

        # Sky colour.
        if sky_col_tag in starry_sky_rgb1:
            sky_texture = mjcf_root.asset.texture[0]
            sky_texture.mark = 'random'
            sky_texture.rgb1 = np.array(starry_sky_rgb1[sky_col_tag])
            sky_texture.width = 800
            sky_texture.height = 800
        elif sky_col_tag == 6:
            # Plain white sky: both texture colours are set to white.
            sky_texture = mjcf_root.asset.texture[0]
            sky_texture.rgb1 = np.array([1., 1., 1.])
            sky_texture.rgb2 = np.array([1., 1., 1.])
            sky_texture.width = 800
            sky_texture.height = 800
        elif sky_col_tag == 7:
            # Black sky.
            sky_texture = mjcf_root.asset.texture[0]
            sky_texture.rgb1 = np.array([0, 0, 0])
            sky_texture.width = 800
            sky_texture.height = 800
        # Any other tag (including 8) leaves the sky texture untouched; the
        # ground texture is resized below regardless of the tag.

        # TODO: Remove the checkerboard groundplane.
        # For now, shrinking its texture to 1x1 makes it render as a white
        # plane.
        ground_texture = mjcf_root.asset.texture[1]
        ground_texture.width = 1
        ground_texture.height = 1

        # Add sites for visualizing the prop and target bounding boxes.
        workspaces.add_bbox_site(body=mjcf_root.worldbody,
                                 lower=workspace.tcp_bbox.lower,
                                 upper=workspace.tcp_bbox.upper,
                                 rgba=constants.GREEN,
                                 name='tcp_spawn_area')
        workspaces.add_bbox_site(body=mjcf_root.worldbody,
                                 lower=workspace.target_bbox.lower,
                                 upper=workspace.target_bbox.upper,
                                 rgba=constants.BLUE,
                                 name='target_spawn_area')
Пример #7
0
    def __init__(self, arena, arm, hand, prop, obs_settings, workspace,
                 control_timestep, cradle):
        """Sets up the robot, prop and pedestal for a `Place` task.

        Args:
          arena: `composer.Entity` instance.
          arm: `robot_base.RobotArm` instance.
          hand: `robot_base.RobotHand` instance.
          prop: `composer.Entity` instance.
          obs_settings: `observations.ObservationSettings` instance.
          workspace: A `_PlaceWorkspace` instance providing `arm_offset`,
            `tcp_bbox`, `prop_bbox` and `target_bbox`.
          control_timestep: Float specifying the control timestep in seconds.
          cradle: `composer.Entity` onto which the `prop` must be placed.
        """
        self._arena = arena
        self._arm = arm
        self._hand = hand

        # Mount the hand on the arm, then the arm on the arena.
        self._arm.attach(self._hand)
        self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
        self.control_timestep = control_timestep

        # Camera observables form the initial set of task observables.
        self._task_observables = cameras.add_camera_observables(
            arena, obs_settings, cameras.FRONT_CLOSE)

        tcp_distribution = distributions.Uniform(*workspace.tcp_bbox)
        self._tcp_initializer = initializers.ToolCenterPointInitializer(
            self._hand,
            self._arm,
            position=tcp_distribution,
            quaternion=workspaces.DOWN_QUATERNION)

        self._prop = prop
        self._prop_frame = self._arena.add_free_entity(prop)
        self._pedestal = Pedestal(cradle=cradle, target_radius=_TARGET_RADIUS)
        self._arena.attach(self._pedestal)

        # Every pedestal observable uses the prop-pose observation settings.
        pose_settings = obs_settings.prop_pose
        for pedestal_obs in self._pedestal.observables.as_dict().values():
            pedestal_obs.configure(**pose_settings._asdict())

        prop_distribution = distributions.Uniform(*workspace.prop_bbox)
        self._prop_placer = initializers.PropPlacer(
            props=[prop],
            position=prop_distribution,
            quaternion=workspaces.uniform_z_rotation,
            settle_physics=True,
            max_attempts_per_prop=50)

        pedestal_distribution = distributions.Uniform(*workspace.target_bbox)
        self._pedestal_placer = initializers.PropPlacer(
            props=[self._pedestal],
            position=pedestal_distribution,
            settle_physics=False)

        # Sites for visual debugging of the spawn bounding boxes.
        worldbody = self.root_entity.mjcf_model.worldbody
        spawn_areas = (
            (workspace.tcp_bbox, constants.GREEN, 'tcp_spawn_area'),
            (workspace.prop_bbox, constants.BLUE, 'prop_spawn_area'),
            (workspace.target_bbox, constants.CYAN, 'pedestal_spawn_area'),
        )
        for bbox, rgba, site_name in spawn_areas:
            workspaces.add_bbox_site(body=worldbody,
                                     lower=bbox.lower,
                                     upper=bbox.upper,
                                     rgba=rgba,
                                     name=site_name)