예제 #1
0
 def __init__(self, interface):
     """ Attach this interface on top of a previous one.

         :param  interface      previous interface to wrap on; when None, a
                                fresh RawInt() becomes the inner interface"""
     # Derived classes should call this first in their own `__init__()`,
     # naming their own class explicitly, e.g.
     #   super(DerivedInt, self).__init__(interface)
     # (super(self.__class__, self) recurses forever once subclassed again).
     self.inter = RawInt() if interface is None else interface
     # The wrapped object must itself be a RawInt (or derived from it).
     assert isinstance(self.inter, RawInt)
예제 #2
0
def make_sc2full_v8_interface(zstat_data_src='',
                              mmr=3500,
                              dict_space=False,
                              zstat_presort_order_name=None,
                              zmaker_version='v4',
                              output_map_size=(128, 128),
                              **kwargs):
    """Assemble the SC2 full interface stack (v8).

    Wrapping order, innermost first:
    RawInt -> FullObsIntV7 -> FullActIntV6 -> ActAsObsSC2 -> NoopActIntV4.

    Any extra ``**kwargs`` are forwarded unchanged to BOTH FullObsIntV7 and
    FullActIntV6, so they must be acceptable to both constructors.

    :param output_map_size: passed as ``output_map_resolution`` to the obs
        interface and as ``map_resolution`` to the act interface
    :param dict_space: forwarded to both; also decides how the noop entry is
        read from an action (dict key vs. tuple index)
    :return: the outermost (NoopActIntV4) interface
    """
    from arena.interfaces.sc2full_formal.obs_int import (ActAsObsSC2,
                                                         FullObsIntV7)
    from arena.interfaces.sc2full_formal.act_int import (FullActIntV6,
                                                         NoopActIntV4)
    from arena.interfaces.raw_int import RawInt

    # Allowed noop numbers: 1, 2, ..., 128.
    noop_nums = list(range(1, 129))

    # This observation interface is built for game core 4.10.0.
    inter = FullObsIntV7(
        RawInt(),
        zstat_data_src=zstat_data_src,
        mmr=mmr,
        dict_space=dict_space,
        zstat_presort_order_name=zstat_presort_order_name,
        game_version='4.10.0',
        zmaker_version=zmaker_version,
        output_map_resolution=output_map_size,
        **kwargs
    )
    inter = FullActIntV6(
        inter,
        max_noop_num=len(noop_nums),
        map_resolution=output_map_size,
        dict_space=dict_space,
        **kwargs
    )
    inter = ActAsObsSC2(inter)

    def noop_func(action):
        # Read the noop-number entry: key 'A_NOOP_NUM' for dict spaces,
        # otherwise position 1 of the action.
        return action['A_NOOP_NUM'] if dict_space else action[1]

    return NoopActIntV4(inter, noop_nums=noop_nums, noop_func=noop_func)
예제 #3
0
 def __init__(self, agent, interface=None, step_mul=1):
     """ Wrap an agent with an interface chain.

         :param  agent       the wrapped agent (stored as self.agent)
         :param  interface   interface to use; when None (the default) a
                             fresh RawInt() is created for this wrapper
         :param  step_mul    stored as self.step_mul (presumably the number of
                             env steps per agent decision -- confirm)"""
     super(AgtIntWrapper, self).__init__()
     self.agent = agent
     # Create the default RawInt per instance. A `RawInt()` default argument
     # is evaluated only once, at definition time, so every wrapper built
     # without an explicit interface would share (and clobber) one RawInt.
     self.inter = RawInt() if interface is None else interface
     self.step_mul = step_mul
     self.act = None  # no action taken yet
     assert isinstance(self.inter, RawInt)
예제 #4
0
    def __init__(self, interface, sub_interfaces=()):
        """ Initialization.

        :param  interface        previous interface to wrap on; None means a
                                 fresh RawInt()
        :param  sub_interfaces   iterable of interfaces to combine; each None
                                 entry is replaced by a fresh RawInt()
        """
        self.inter = RawInt() if interface is None else interface
        assert isinstance(self.inter, RawInt)

        # Single pass over `sub_interfaces`: iterating it twice (once for
        # list(), once for the checks) would silently skip the None
        # replacement and type checks when a generator is passed.
        self.sub_interfaces = []
        for sub_inter in sub_interfaces:
            if sub_inter is None:
                sub_inter = RawInt()
            else:
                assert isinstance(sub_inter, RawInt)
            self.sub_interfaces.append(sub_inter)
예제 #5
0
 def _install_interfaces(i_agent):
     """Build the interface chain for agent number `i_agent`.

     Chain, innermost first: RawInt -> ReshapedFrameObsInt (bound to
     ``env.envs[i_agent]``) -> Discrete6ActionInt. ``env`` is not a parameter;
     it is resolved from the enclosing scope.
     """
     return Discrete6ActionInt(
         ReshapedFrameObsInt(RawInt(), env.envs[i_agent]))
예제 #6
0
def make_sc2full_v8_interface(zstat_data_src='',
                              mmr=3500,
                              max_bo_count=50,
                              max_bobt_count=50,
                              dict_space=False,
                              verbose=0,
                              zstat_presort_order_name=None,
                              correct_pos_radius=2.0,
                              correct_building_pos=False,
                              zmaker_version='v4',
                              inj_larv_rule=False,
                              ban_zb_rule=False,
                              ban_rr_rule=False,
                              ban_hydra_rule=False,
                              rr_food_cap=40,
                              zb_food_cap=10,
                              hydra_food_cap=10,
                              mof_lair_rule=False,
                              hydra_spire_rule=False,
                              overseer_rule=False,
                              expl_map_rule=False,
                              baneling_rule=False,
                              add_cargo_to_units=False,
                              output_map_size=(128, 128),
                              crop_to_playable_area=False,
                              ab_dropout_list=None,
                              **kwargs):
    """Assemble the SC2 full interface stack (v8).

    Wrapping order, innermost first:
    RawInt -> FullObsIntV7 -> FullActIntV6 -> ActAsObsSC2 -> NoopActIntV4.

    Routing of options:
      * FullObsIntV7 only: zstat_data_src, mmr, max_bo_count, max_bobt_count,
        zstat_presort_order_name, zmaker_version, the *_rule flags, the
        *_food_cap values, add_cargo_to_units, ab_dropout_list.
      * FullActIntV6 only: correct_pos_radius, correct_building_pos, verbose.
      * Both: dict_space, crop_to_playable_area, output_map_size (named
        ``output_map_resolution`` on the obs side and ``map_resolution`` on
        the act side).
    Extra ``**kwargs`` are accepted but not used.

    :return: the outermost (NoopActIntV4) interface
    """
    from arena.interfaces.sc2full_formal.obs_int import (ActAsObsSC2,
                                                         FullObsIntV7)
    from arena.interfaces.sc2full_formal.act_int import (FullActIntV6,
                                                         NoopActIntV4)
    from arena.interfaces.raw_int import RawInt

    # Allowed noop numbers: 1, 2, ..., 128.
    noop_nums = list(range(1, 129))

    # This observation interface is built for game core 4.10.0.
    inter = FullObsIntV7(
        RawInt(),
        zstat_data_src=zstat_data_src,
        mmr=mmr,
        max_bo_count=max_bo_count,
        max_bobt_count=max_bobt_count,
        dict_space=dict_space,
        zstat_presort_order_name=zstat_presort_order_name,
        game_version='4.10.0',
        zmaker_version=zmaker_version,
        inj_larv_rule=inj_larv_rule,
        ban_zb_rule=ban_zb_rule,
        ban_rr_rule=ban_rr_rule,
        ban_hydra_rule=ban_hydra_rule,
        rr_food_cap=rr_food_cap,
        zb_food_cap=zb_food_cap,
        hydra_food_cap=hydra_food_cap,
        mof_lair_rule=mof_lair_rule,
        hydra_spire_rule=hydra_spire_rule,
        overseer_rule=overseer_rule,
        expl_map_rule=expl_map_rule,
        baneling_rule=baneling_rule,
        add_cargo_to_units=add_cargo_to_units,
        output_map_resolution=output_map_size,
        crop_to_playable_area=crop_to_playable_area,
        ab_dropout_list=ab_dropout_list,
    )
    inter = FullActIntV6(
        inter,
        max_noop_num=len(noop_nums),
        correct_pos_radius=correct_pos_radius,
        correct_building_pos=correct_building_pos,
        map_resolution=output_map_size,
        crop_to_playable_area=crop_to_playable_area,
        dict_space=dict_space,
        verbose=verbose,
    )
    inter = ActAsObsSC2(inter)

    def noop_func(action):
        # Read the noop-number entry: key 'A_NOOP_NUM' for dict spaces,
        # otherwise position 1 of the action.
        return action['A_NOOP_NUM'] if dict_space else action[1]

    return NoopActIntV4(inter, noop_nums=noop_nums, noop_func=noop_func)
예제 #7
0
class Interface(RawInt):
    """
        Base link in a chain of interfaces.

        Each Interface keeps the previous interface in `self.inter` and, by
        default, forwards every call to it. Derived classes override the
        hooks below to add their own obs/act transformation.
    """
    inter = None

    def __init__(self, interface):
        """ Wrap `interface`; None starts a new chain from a fresh RawInt().

            :param  interface      previous interface to wrap on"""
        # Derived classes should call this first in their own `__init__()`,
        # naming their own class explicitly, e.g.
        #   super(DerivedInt, self).__init__(interface)
        # (super(self.__class__, self) recurses forever once subclassed again).
        self.inter = RawInt() if interface is None else interface
        assert isinstance(self.inter, RawInt)

    def reset(self, obs, **kwargs):
        """ Reset the whole chain below this interface.
            Observation/action spaces may only become known on reset().

            :param  obs            input obs (received by the root interface)"""
        # Derived classes should call the parent reset() first, then do
        # their own resetting.
        self.inter.reset(obs, **kwargs)

    def obs_trans(self, obs):
        """ Observation Transformation (recursive): the inner interface
            transforms `obs` first; derived classes post-process its result. """
        return self.inter.obs_trans(obs)

    def act_trans(self, act):
        """ Action Transformation (recursive): derived classes transform
            `act` first, then hand it to the inner interface. """
        # TODO(peng): raise NotImplementedError, encourage recursive call in derived class
        return self.inter.act_trans(act)

    def unwrapped(self):
        """ Return the root instance at the bottom of the chain.
            Typically used to store global information, e.g. RawInt() keeps
            the raw obs and raw act. """
        return self.inter.unwrapped()

    @property
    def observation_space(self):
        """ Observation Space, taken from the inner interface unless a
            derived class overrides this property. """
        return self.inter.observation_space

    @property
    def action_space(self):
        """ Action Space, taken from the inner interface unless a derived
            class overrides this property. """
        return self.inter.action_space

    def setup(self, observation_space, action_space):
        root = self.unwrapped()
        root.setup(observation_space, action_space)

    def __str__(self):
        """ Name of the whole stack, e.g. 'RawInt<Foo><Bar>'. """
        # TODO(peng): return my_name + '<' + wrapped_interface_name + '>'
        return '%s<%s>' % (str(self.inter), self.__class__.__name__)
예제 #8
0
class Combine(Interface):
    """
        Concat several Interface to form a new Interface.

        `self.inter` is applied first. Its observation and action spaces must
        be spaces.Tuple with one component per sub-interface, and the i-th
        sub-interface handles the i-th component. This interface's own spaces
        are the Tuple of the sub-interfaces' spaces.
    """
    inter = None

    def __init__(self, interface, sub_interfaces=()):
        """ Initialization.

        :param  interface        previous interface to wrap on; None means a
                                 fresh RawInt()
        :param  sub_interfaces   iterable of interfaces to combine; each None
                                 entry is replaced by a fresh RawInt()
        """
        # Same handling of `interface` as Interface.__init__ (None -> RawInt()).
        super(Combine, self).__init__(interface)

        # Single pass, so a generator argument is fully checked/normalised
        # (iterating it twice would see an exhausted generator the second time).
        self.sub_interfaces = []
        for sub_inter in sub_interfaces:
            if sub_inter is None:
                sub_inter = RawInt()
            else:
                assert isinstance(sub_inter, RawInt)
            self.sub_interfaces.append(sub_inter)

    def setup(self, observation_space, action_space):
        """ Set up the root interface with the given spaces, and the i-th
            sub-interface with the i-th component of each Tuple space. """
        self.unwrapped().setup(observation_space, action_space)
        for i, sub_inter in enumerate(self.sub_interfaces):
            sub_inter.setup(observation_space.spaces[i],
                            action_space.spaces[i])

    def reset(self, obs, **kwargs):
        """ Reset `self.inter`, then set up and reset each sub-interface on
            its component of `self.inter`'s Tuple spaces and of `obs`. """
        # Reset the inner chain first: spaces may only be specified on
        # reset(), so reading them beforehand could give stale values.
        # kwargs are forwarded exactly as Interface.reset() does.
        self.inter.reset(obs, **kwargs)
        inter_ob_sp = self.inter.observation_space
        inter_ac_sp = self.inter.action_space
        assert isinstance(inter_ob_sp, spaces.Tuple)
        assert isinstance(inter_ac_sp, spaces.Tuple)
        assert len(inter_ob_sp.spaces) == len(self.sub_interfaces)
        for i, sub_inter in enumerate(self.sub_interfaces):
            sub_inter.setup(inter_ob_sp.spaces[i], inter_ac_sp.spaces[i])
            sub_inter.reset(obs[i])

    def obs_trans(self, obs):
        """ Transform `obs` through `self.inter`, then each component through
            the matching sub-interface; combine via `_obs_trans`. """
        obs = self.inter.obs_trans(obs)
        sub_obs = tuple(
            sub_inter.obs_trans(ob)
            for ob, sub_inter in zip(obs, self.sub_interfaces))
        return self._obs_trans(obs, sub_obs)

    def _obs_trans(self, obs, sub_obs):
        """ Observation Transformation.
            obs is observation from self.inter
            sub_obs are observations from sub_interfaces"""
        return sub_obs

    def act_trans(self, act):
        """ Action Transformation, the reverse of obs_trans: each component of
            `act` goes through its sub-interface, then `self.inter` transforms
            the combined result. """
        # Without this override the inherited Interface.act_trans would hand
        # `act` (expressed in the sub-interfaces' combined action_space)
        # straight to self.inter, skipping the sub-interfaces entirely.
        sub_act = self._act_trans(act)
        return self.inter.act_trans(sub_act)

    def _act_trans(self, act):
        """ Apply the i-th sub-interface's act_trans to act[i]; returns a list. """
        return [
            sub_inter.act_trans(ac)
            for ac, sub_inter in zip(act, self.sub_interfaces)
        ]

    @property
    def observation_space(self):
        return spaces.Tuple(
            [inter.observation_space for inter in self.sub_interfaces])

    @property
    def action_space(self):
        return spaces.Tuple(
            [inter.action_space for inter in self.sub_interfaces])

    def __str__(self):
        """ Get the name of all stacked interface. """
        s = str(self.inter)
        combine_s = str([str(inter) for inter in self.sub_interfaces])
        return s + '<' + self.__class__.__name__ + combine_s + '>'