示例#1
0
    def __init__(
        self,
        game=None,
        frame_skip=4,  # Frames per step (>=1).
        num_img_obs=4,  # Number of (past) frames in observation (>=1).
        clip_reward=True,
        episodic_lives=True,
        max_start_noops=30,
        repeat_action_probability=0.,
        horizon=27000,
    ):
        """Create the environment via ``create_env`` and record its spaces.

        ``game`` falls back to ``'doom_battle'`` when it is falsy. The other
        keyword arguments are accepted for signature compatibility but are not
        read by this constructor.
        """
        env_name = game or 'doom_battle'

        doom_cfg = default_cfg(env=env_name)
        doom_cfg.wide_aspect_ratio = False
        self.env = create_env(env_name, cfg=doom_cfg)

        # The observation space is exposed exactly as the wrapped env reports it.
        self._observation_space = self.env.observation_space

        # Only discrete gym action spaces are supported: the count comes from ``.n``.
        num_actions = self.env.action_space.n
        self._action_space = IntBox(low=0, high=num_actions)

        self.first_reset = True
    def __init__(
        self,
        game=None,
        frame_skip=4,  # Frames per step (>=1).
        num_img_obs=4,  # Number of (past) frames in observation (>=1).
        clip_reward=True,
        episodic_lives=True,
        max_start_noops=30,
        repeat_action_probability=0.,
        horizon=27000,
    ):
        """Create the environment via ``create_env`` with a fixed low-res config.

        The config is forced to ``res_w=96``, ``res_h=72``, the
        ``dmlab_throughput_benchmark`` flag, and the ``'software'`` renderer.
        Only ``game`` is read; the remaining arguments are accepted for
        signature compatibility and ignored here.
        """
        # NOTE(review): ``game`` defaults to None and is forwarded unchanged;
        # confirm default_cfg/create_env accept None.
        env_cfg = default_cfg(env=game)
        overrides = (
            ('res_w', 96),
            ('res_h', 72),
            ('dmlab_throughput_benchmark', True),
            ('dmlab_renderer', 'software'),
        )
        for attr_name, attr_value in overrides:
            setattr(env_cfg, attr_name, attr_value)

        self.env = create_env(game, cfg=env_cfg)
        self._observation_space = self.env.observation_space

        # Discrete action spaces only: the action count comes from ``.n``.
        action_count = self.env.action_space.n
        self._action_space = IntBox(low=0, high=action_count)
    def __init__(self, scenario_name='', scenario_cfg=None, max_steps=200,
                 gif_freq=500, steps_per_action=4, image_dir='images/test',
                 viewport=False):
        """Create a Holodeck environment and derive rlpyt spaces from it.

        Args:
            scenario_name: forwarded to ``holodeck.make``.
            scenario_cfg: optional scenario config; when given, its
                ``'package_name'`` is installed first if not already present.
            max_steps: stored as ``self._max_steps``.
            gif_freq: stored as ``self.gif_freq``.
            steps_per_action: stored as ``self._steps_per_action``.
            image_dir: stored as ``self._image_dir``.
            viewport: forwarded to ``holodeck.make`` as ``show_viewport``.
        """
        if scenario_cfg is not None:
            package = scenario_cfg['package_name']
            if package not in holodeck.installed_packages():
                holodeck.install(package)

        self._env = holodeck.make(scenario_name=scenario_name,
                                  scenario_cfg=scenario_cfg,
                                  show_viewport=viewport)

        # Action space: discrete actions reuse Holodeck's integer bounds;
        # continuous ones are described by a FloatBox bounded by -1 and 1.
        if not self.is_action_continuous:
            self._action_space = IntBox(self._env.action_space.get_low(),
                                        self._env.action_space.get_high(),
                                        ())
        else:
            self._action_space = FloatBox(-1, 1, self._env.action_space.shape)

        # Scan the agent's sensors (skipping any whose name contains 'Task').
        # 3-D sensors count as images: channels are summed and the largest
        # width/height is kept. Every other sensor is flattened into one vector.
        img_width = img_height = img_channels = flat_size = 0
        for sensor in self._env._agent.sensors.values():
            if 'Task' in sensor.name:
                continue
            dims = sensor.sensor_data.shape
            if len(dims) != 3:
                flat_size += np.prod(dims)
                continue
            img_width = max(img_width, dims[0])
            img_height = max(img_height, dims[1])
            img_channels += dims[2]

        image_shape = (img_channels, img_width, img_height)
        vector_shape = (flat_size,)
        if img_channels > 0 and flat_size == 0:
            self.has_img, self.has_lin = True, False
            self._observation_space = FloatBox(0, 1, image_shape)
        elif flat_size > 0 and img_channels == 0:
            self.has_img, self.has_lin = False, True
            self._observation_space = FloatBox(-256, 256, vector_shape)
        else:
            # Mixed sensors -- and, as written, also the case with no
            # non-Task sensors at all.
            self.has_img, self.has_lin = True, True
            self._observation_space = Composite([
                FloatBox(0, 1, image_shape),
                FloatBox(-256, 256, vector_shape)],
                HolodeckObservation)

        # Bookkeeping for stepping and GIF recording.
        self._max_steps = max_steps
        self._steps_per_action = steps_per_action
        self._image_dir = image_dir
        self.gif_freq = gif_freq
        self.gif_images = []
        self.curr_step = 0
        self.rollout_count = -1
示例#4
0
 def observation_space(self):
     """Describe observations as channel-first RGB images.

     Returns an ``IntBox`` with ``low=0``, ``high=255``, dtype ``"uint8"``,
     and shape ``(3,) + self._size``.
     """
     rgb_shape = (3, ) + self._size
     return IntBox(low=0, high=255, shape=rgb_shape, dtype="uint8")
示例#5
0
 def observation_space(self):
     """Describe observations as channel-first uint8 images.

     There is one channel when ``self._grayscale`` is truthy and three
     otherwise, followed by ``self._size``. Bounds are ``low=0``, ``high=255``.
     """
     channels = 1 if self._grayscale else 3
     return IntBox(low=0, high=255, shape=(channels, ) + self._size,
                   dtype=np.uint8)
示例#6
0
    def __init__(
        self,
        assortment_size=1000,  # number of items to train
        max_stock=1000,  # Size of maximum stock
        clip_reward=False,
        episodic_lives=True,
        repeat_action_probability=0.0,
        horizon=10000,
        seed=None,
        substep_count=4,
        bucket_customers=None,  # default: torch.tensor([800., 400., 500., 900.])
        bucket_cov=None,  # default: torch.eye(4) / 100
        forecastBias=0.0,
        forecastVariance=0.0,
        freshness=1,
        utility_function='homogeneous',
        utility_weights=None,  # default: {'alpha': 1., 'beta': 1., 'gamma': 1.}
        characDim=4,
        lead_time=1,  # Defines how quickly the orders goes through the buffer - also impacts the relevance of the observation
        lead_time_fast=0,
        symmetric_action_space=False,
    ):
        """Create a store-replenishment environment.

        Args (those read directly here):
            assortment_size: number of items; the leading dimension of the
                action space, the stock tensor, and the observation space.
            max_stock: upper action bound and width of the stock tensor.
            horizon: stored as ``self._horizon`` (cast to int); also bounds
                the seasonality ``arange`` tensors.
            seed, freshness: forwarded to ``Assortment``.
            bucket_customers: per-bucket customer means (tensor). When None,
                defaults to ``torch.tensor([800., 400., 500., 900.])``.
            bucket_cov: covariance for the customer ``MultivariateNormal``.
                When None, defaults to ``torch.eye(4) / 100``.
            forecastBias, forecastVariance: parameters of ``self._bias``.
            utility_function: ``'linear'``, ``'loglinear'``, ``'cobbdouglas'``,
                or ``'homogeneous'`` selects a utility class. Any other value is
                used as-is.
            utility_weights: keyword arguments for the selected utility class.
                When None, defaults to ``{'alpha': 1., 'beta': 1., 'gamma': 1.}``.
            characDim, lead_time, lead_time_fast: widen the observation.
                ``lead_time`` and ``lead_time_fast`` are also passed to
                ``create_buffers``.
            symmetric_action_space: if true, actions are a FloatBox in
                ``[-max_stock / 2, max_stock / 2]``. Otherwise they are an
                IntBox with ``low=0, high=max_stock``.

        After setup, ``lead_time`` initial orders are placed from the forecast.
        """
        # Resolve defaults here rather than in the signature. Otherwise the
        # tensors/dict would be built once at definition time and shared by
        # every instance. This must precede save__init__args, which receives
        # locals(), so the saved attributes hold the resolved values.
        if bucket_customers is None:
            bucket_customers = torch.tensor([800., 400., 500., 900.])
        if bucket_cov is None:
            bucket_cov = torch.eye(4) / 100
        if utility_weights is None:
            utility_weights = {'alpha': 1., 'beta': 1., 'gamma': 1.}
        save__init__args(locals(), underscore=True)
        logging.info("Creating new StoreEnv")

        self.bucket_customers = bucket_customers

        # Spaces

        if symmetric_action_space:
            self._action_space = FloatBox(low=-max_stock / 2,
                                          high=max_stock / 2,
                                          shape=[assortment_size])
        else:
            self._action_space = IntBox(low=0,
                                        high=max_stock,
                                        shape=[assortment_size])
        self.stock = torch.zeros(assortment_size,
                                 max_stock,
                                 requires_grad=False)

        # correct high with max shelf life

        self._observation_space = FloatBox(
            low=0,
            high=1000,
            shape=(assortment_size,
                   max_stock + characDim + lead_time + lead_time_fast + 1))
        self._horizon = int(horizon)
        self.assortment = Assortment(assortment_size, freshness, seed)
        self._repeater = torch.stack(
            (self.assortment.shelf_lives, torch.zeros(
                self._assortment_size))).transpose(0, 1).reshape(-1).detach()
        self.forecast = torch.zeros(assortment_size, 1)  # DAH forecast.
        self._step_counter = 0

        # Needs to move towards env parameters

        self._customers = \
            d.multivariate_normal.MultivariateNormal(bucket_customers,
                                                     bucket_cov)
        self.assortment.base_demand = \
            self.assortment.base_demand.detach() \
            / bucket_customers.sum()
        # NOTE(review): forecastVariance is passed as Normal's scale (a
        # standard deviation). The default 0.0 may be rejected when torch
        # distribution argument validation is enabled -- confirm.
        self._bias = d.normal.Normal(forecastBias, forecastVariance)

        # We want a yearly seasonality - We have a cosinus argument and a phase.
        # Note that, as we take the absolute value, 2*pi/365 becomes pi/365.

        self._year_multiplier = torch.arange(0.0, horizon, pi / 365)
        self._week_multiplier = torch.arange(0.0, horizon, pi / 7)
        self._phase = 2 * pi * torch.rand(assortment_size)
        self._phase2 = 2 * pi * torch.rand(assortment_size)
        self.create_buffers(lead_time, lead_time_fast)
        if utility_function == 'linear':
            self.utility_function = LinearUtility(**utility_weights)
        elif utility_function == 'loglinear':
            self.utility_function = LogLinearUtility(**utility_weights)
        elif utility_function == 'cobbdouglas':
            self.utility_function = CobbDouglasUtility(**utility_weights)
        elif utility_function == 'homogeneous':
            self.utility_function = HomogeneousReward(**utility_weights)
        else:
            self.utility_function = utility_function
        self._updateEnv()
        # Pre-fill the order pipeline: one forecast-based order per lead-time
        # step, rounded and clamped to [0, max_stock].
        for i in range(self._lead_time):
            # NOTE(review): bucket_customers is indexed by lead-time step, so
            # lead_time must not exceed len(bucket_customers) -- confirm.
            units_to_order = torch.as_tensor(
                self.forecast.squeeze() * bucket_customers[i]).round().clamp(
                    0, self._max_stock)
            self._addStock(units_to_order)
示例#7
0
 def __init__(
     self,
     level,
     height=72,
     width=96,
     action_repeat=4,
     frame_history=1,
     renderer="hardware",
     fps=None,
     episode_length_seconds=None,
     config_kwargs=None,
     cache_dir="/data/adam/dmlab_cache",
     gpu_device_index="EGL_DEVICE_ID",
 ):
     """Wrap a ``deepmind_lab.Lab`` level with a fixed discrete action set.

     Args:
         level: level name; names found in ``DMLAB30`` are prefixed with
             ``"/contributed/dmlab30/"``.
         height, width: RGB frame size passed to the Lab config.
         action_repeat: stored as ``self._action_repeat``.
         frame_history: number of stacked RGB frames in the observation.
             When it is > 1, a deque of that length is created.
         renderer: forwarded to ``deepmind_lab.Lab``.
         fps, episode_length_seconds: optional Lab config entries.
         config_kwargs: extra Lab config entries. Any key that collides with
             one set here raises ``KeyError``.
         cache_dir: directory for a ``LevelCache``, or ``None`` for no cache.
         gpu_device_index: an int index, or the name of an environment
             variable holding it (``"0"`` if unset). ``None`` omits it.
     """
     full_level = ("/contributed/dmlab30/" + level
                   if level in DMLAB30 else level)
     cache = LevelCache(cache_dir) if cache_dir is not None else None

     config = {"height": str(height), "width": str(width)}
     if fps is not None:
         config["fps"] = str(fps)
     if episode_length_seconds is not None:
         config["episodeLengthSeconds"] = str(episode_length_seconds)
     if gpu_device_index is not None:
         config["gpuDeviceIndex"] = (
             str(gpu_device_index) if isinstance(gpu_device_index, int)
             else os.environ.get(gpu_device_index, "0"))
     if config_kwargs is not None:
         clashes = config.keys() & config_kwargs.keys()
         if clashes:
             raise KeyError(f"Had duplicate key(s) in config_kwargs: "
                            f"{clashes}")
         config.update(config_kwargs)

     self.dmlab_env = deepmind_lab.Lab(
         level=full_level,
         observations=["RGB"],
         config=config,
         renderer=renderer,
         level_cache=cache,
     )

     # Discrete action index -> 7-element int32 Lab action vector. Column
     # meanings, inferred from the row labels: 0 look, 2 strafe, 3 move, 4 fire.
     labelled_actions = (
         ("Forward", [0, 0, 0, 1, 0, 0, 0]),
         ("Backward", [0, 0, 0, -1, 0, 0, 0]),
         ("Move Left", [0, 0, -1, 0, 0, 0, 0]),
         ("Move Right", [0, 0, 1, 0, 0, 0, 0]),
         ("Look Left", [-20, 0, 0, 0, 0, 0, 0]),
         ("Look Right", [20, 0, 0, 0, 0, 0, 0]),
         ("Left Forward", [-20, 0, 0, 1, 0, 0, 0]),
         ("Right Forward", [20, 0, 0, 1, 0, 0, 0]),
         ("Fire", [0, 0, 0, 0, 1, 0, 0]),
     )
     self._action_map = np.array([vec for _, vec in labelled_actions],
                                 dtype=np.int32)
     self._action_space = IntBox(low=0, high=len(self._action_map))

     stacked_shape = (3 * frame_history, height, width)
     self._observation_space = IntBox(low=0, high=256, shape=stacked_shape,
                                      dtype=np.uint8)
     self._zero_obs = np.zeros((3, height, width), dtype=np.uint8)
     if frame_history > 1:
         self._obs_deque = deque(maxlen=frame_history)
     self._frame_history = frame_history
     self._action_repeat = action_repeat
    def __init__(
            self,
            game="pong",
            frame_skip=4,  # Frames per step (>=1).
            num_img_obs=4,  # Number of (past) frames in observation (>=1) - "frame stacking".
            clip_reward=True,
            episodic_lives=True,
            fire_on_reset=False,
            max_start_noops=30,
            repeat_action_probability=0.,
            horizon=27000,
            no_extrinsic=False,
            no_negative_reward=False,
            normalize_obs=False,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=0,
            record_dir=None,
            score_multiplier=1.0):
        """Create an ALE-backed Atari environment with optional recording.

        All constructor arguments are handed to ``save__init__args``. The ones
        read directly here are listed below.

        Args:
            game: ROM name, resolved with ``atari_py.get_game_path``.
            num_img_obs: number of stacked frames in the observation.
            repeat_action_probability: set on the ALE before loading the ROM.
            horizon: stored as ``self._horizon`` (cast to int).
            downsampling_scheme: ``'classical'`` gives 84x84 frames; ``'new'``
                gives frames 80 wide by 104 high.
            record_freq: stored as ``self._record_freq``.
            record_dir: base directory for recordings; videos go to
                ``<record_dir>/videos``. May be ``None`` when not recording,
                in which case ``self._video_dir`` is ``None``.
            score_multiplier: stored as ``self._multiplier``.

        Raises:
            ValueError: if ``downsampling_scheme`` is not supported.
            IOError: if the ROM file for ``game`` does not exist.
        """
        # Keep this first: it receives locals(), so it must run before any
        # other local variable is defined.
        save__init__args(locals(), underscore=True)

        # (W, H) of one downsampled frame for each supported scheme.
        frame_shapes = {'classical': (84, 84), 'new': (80, 104)}
        if downsampling_scheme not in frame_shapes:
            raise ValueError(
                "Unknown downsampling_scheme {!r}; expected one of {}".format(
                    downsampling_scheme, sorted(frame_shapes)))

        # ALE
        game_path = atari_py.get_game_path(game)
        if not os.path.exists(game_path):
            raise IOError("You asked for game {} but path {} does not "
                          "exist".format(game, game_path))
        self.ale = atari_py.ALEInterface()
        self.ale.setFloat(b'repeat_action_probability',
                          repeat_action_probability)
        self.ale.loadROM(game_path)

        # Spaces
        self._action_set = self.ale.getMinimalActionSet()
        self._action_space = IntBox(low=0, high=len(self._action_set))
        self._frame_shape = frame_shapes[downsampling_scheme]  # (W, H)
        # Observations are channel-first: (frames, H, W).
        obs_shape = (num_img_obs, self._frame_shape[1], self._frame_shape[0])
        self._observation_space = IntBox(low=0,
                                         high=255,
                                         shape=obs_shape,
                                         dtype="uint8")
        self._max_frame = self.ale.getScreenGrayscale()
        self._raw_frame_1 = self._max_frame.copy()
        self._raw_frame_2 = self._max_frame.copy()
        self._obs = np.zeros(shape=obs_shape, dtype="uint8")

        # Settings
        action_meanings = self.get_action_meanings()
        self._has_fire = "FIRE" in action_meanings
        self._has_up = "UP" in action_meanings
        self._horizon = int(horizon)
        self._multiplier = score_multiplier

        # Recording
        self.record_env = False  # set in the sampling process for environment 0
        self._record_episode = False
        self._record_freq = record_freq
        # record_dir defaults to None; os.path.join(None, ...) would raise
        # TypeError, so only build the video path when a directory was given.
        self._video_dir = (None if record_dir is None
                           else os.path.join(record_dir, 'videos'))
        if "TMPDIR" in os.environ:
            # Frames go under $TMPDIR when it is set.
            self._frames_dir = os.path.join(os.environ["TMPDIR"], "frames")
            pathlib.Path(self._frames_dir).mkdir(exist_ok=True)
        elif self._video_dir is not None:
            self._frames_dir = os.path.join(self._video_dir, 'frames')
        else:
            # Neither $TMPDIR nor record_dir: there is no frame directory.
            self._frames_dir = None
        self._episode_number = 0

        self.reset()
示例#9
0
 def observation_space(self):
     """Build the combined observation space of the wrapped env.

     The image space is taken unchanged from
     ``self._env.observation_space['image']``. The mission and direction
     spaces are merged into one flat ``IntBox(0, 1, ...)`` whose length is the
     sum of their first dimensions.
     """
     obs = self._env.observation_space
     mission_shape = obs['mission'].shape[0]
     direction_shape = obs['direction'].shape[0]
     # ``(a + b)`` is just an int. The trailing comma makes it a 1-tuple
     # shape, consistent with the other spaces.
     # NOTE(review): confirm whether IntBox's ``high`` is exclusive. If it is,
     # ``high=1`` only admits 0 and binary features would need ``high=2``.
     vector_space = IntBox(0, 1, (mission_shape + direction_shape,))
     return StateObs(obs['image'], vector_space)
示例#10
0
    def __init__(self,
                 env_name,
                 window_size,
                 force_float32=True,
                 player_reward_shaping=None,
                 observer_reward_shaping=None,
                 max_episode_length=np.inf,
                 add_channel=False):
        """Build a player/observer wrapper around an ``AtariEnv``.

        Args:
            env_name: game name, passed as ``AtariEnv(game=env_name)``.
            window_size: stored as ``self.window_size``.
            force_float32: forwarded to every ``GymSpaceWrapper`` built here.
            player_reward_shaping: reward-shaping callable for the player;
                ``None`` selects ``reward_shaping_ph``.
            observer_reward_shaping: same, for the observer.
            max_episode_length: stored as ``self.max_episode_length``.
            add_channel: when true, the player observation space is rebuilt
                as an ``IntBox`` from the env's observation space (see the
                note at that step).
        """
        self.serial = False
        env = AtariEnv(game=env_name)
        # NOTE(review): presumably cleared so the wrapper base class accepts
        # this env; confirm against the base class.
        env.metadata = None
        env.reward_range = None
        super().__init__(env)
        # Reset once before the single random probing step below. The
        # returned observation is not used.
        o = self.env.reset()
        self.max_episode_length = max_episode_length
        self.curr_episode_length = 0
        self.add_channel = add_channel
        o, r, d, info = self.env.step(self.env.action_space.sample())
        # Walk the ``.env`` chain of nested wrappers looking for a TimeLimit.
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            # NOTE(review): ``info`` is a local that is not used after this
            # point in the method, so this assignment has no lasting effect.
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self.time_limit = time_limit
        # Wrapped env's own spaces, exposed through GymSpaceWrapper.
        self._action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=self.env.action_space.null_value(),
            force_float32=force_float32,
        )
        self._observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=self.env.observation_space.null_value(),
            force_float32=force_float32,
        )
        # NOTE(review): removes the ``action_space``/``observation_space``
        # attributes, presumably so the underscore-prefixed spaces above are
        # the ones used -- confirm.
        del self.action_space
        del self.observation_space
        # Per-step bookkeeping.
        self.player_turn = False
        self.last_done = False
        self.last_reward = 0
        self.last_info = {}
        if player_reward_shaping is None:
            self.player_reward_shaping = reward_shaping_ph
        else:
            self.player_reward_shaping = player_reward_shaping
        if observer_reward_shaping is None:
            self.observer_reward_shaping = reward_shaping_ph
        else:
            self.observer_reward_shaping = observer_reward_shaping
        self.obs_size = self.env.observation_space.shape
        self.window_size = window_size
        self.obs_action_translator = obs_action_translator

        player_obs_space = self.env.observation_space
        if add_channel:
            # NOTE(review): the shape is copied unchanged, so no channel
            # dimension is actually added here -- confirm this is intended.
            player_obs_space = IntBox(low=player_obs_space.low,
                                      high=player_obs_space.high,
                                      shape=player_obs_space.shape,
                                      dtype=player_obs_space.dtype,
                                      null_value=player_obs_space.null_value())
        player_act_space = self.env.action_space
        observer_obs_space = self.env.observation_space
        # Observer action: a 2-D continuous Box from [0, 0] up to the first two
        # observation dimensions (presumably a location in the frame -- confirm).
        observer_act_space = Box(low=np.asarray([0.0, 0.0]),
                                 high=np.asarray([
                                     self.env.observation_space.shape[0],
                                     self.env.observation_space.shape[1]
                                 ]))

        # Per-agent spaces; the observer's null action is explicitly zeros(2).
        self.player_action_space = GymSpaceWrapper(
            space=player_act_space,
            name="act",
            null_value=player_act_space.null_value(),
            force_float32=force_float32)
        self.observer_action_space = GymSpaceWrapper(
            space=observer_act_space,
            name="act",
            null_value=np.zeros(2),
            force_float32=force_float32)
        self.player_observation_space = GymSpaceWrapper(
            space=player_obs_space,
            name="obs",
            null_value=player_obs_space.null_value(),
            force_float32=force_float32)
        self.observer_observation_space = GymSpaceWrapper(
            space=observer_obs_space,
            name="obs",
            null_value=observer_obs_space.null_value(),
            force_float32=force_float32)
示例#11
0
def _dump_pickle(path, obj):
    """Pickle ``obj`` to ``path``, closing the file even if dumping fails."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def build_and_train(game="cartpole",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2,
                    eval=False,
                    serial=False,
                    train_mask=None,
                    wandb_log=False,
                    save_models_to_wandb=False,
                    log_interval_steps=1e5,
                    observation_mode="agent",
                    inc_player_last_act=False,
                    alt_train=False,
                    eval_perf=False,
                    n_steps=50e6,
                    one_agent=False):
    """Configure and train a player/observer (CWTO) agent pair on ``game``.

    Args:
        game: one of ``"cartpole"``, ``"hiv"``, ``"heparin"``, ``"halfcheetah"``.
        run_ID: run identifier passed to ``logger_context``.
        cuda_idx: CUDA device index, or ``None`` to run on CPU.
        sample_mode: ``"serial"`` or ``"cpu"``; selects the sampler class.
        n_parallel: number of worker CPUs listed in the affinity.
        eval: attach an evaluation collector and use ``MinibatchRlEval``.
            (Shadows the builtin; kept for interface compatibility.)
        serial: forwarded to the env/agent wrappers. For non-halfcheetah
            games it also makes the observer action space ``Discrete(2)``.
        train_mask: per-agent training flags; defaults to ``[True, True]``.
        wandb_log: forwarded to the runner.
        save_models_to_wandb: call ``agent.save_models_to_wandb()`` after training.
        log_interval_steps: forwarded to the runner.
        observation_mode: ``"agent"``, ``"random"`` or ``"full"``; sets the
            ``fully_obs`` / ``rand_obs`` env flags.
        inc_player_last_act: append the player's last action to the observer
            observation space.
        alt_train: forwarded to the runner.
        eval_perf: after training, evaluate on 10 envs and pickle the results
            to ``eval_*.pkl`` files in the working directory.
        n_steps: forwarded to the runner.
        one_agent: multiply the observer's discrete action count by the
            player's action count (non-halfcheetah games only).

    Raises:
        ValueError: for an unknown ``sample_mode``, ``observation_mode`` or
            ``game``.
    """
    # A fresh list per call instead of a shared mutable default.
    if train_mask is None:
        train_mask = [True, True]
    if sample_mode not in ("serial", "cpu"):
        raise ValueError(
            f"Unknown sample_mode {sample_mode!r}; expected 'serial' or 'cpu'.")

    # Observation mode -> env observability flags.
    if observation_mode == "agent":
        fully_obs = False
        rand_obs = False
    elif observation_mode == "random":
        fully_obs = False
        rand_obs = True
    elif observation_mode == "full":
        fully_obs = True
        rand_obs = False
    else:
        raise ValueError(
            f"Unknown observation_mode {observation_mode!r}; "
            "expected 'agent', 'random' or 'full'.")

    n_serial = None
    if game == "cartpole":
        work_env = gym.make
        env_name = 'CartPole-v1'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, -4.8000002e+00, -3.4028235e+38, -4.1887903e-01,
            -3.4028235e+38
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 4.8000002e+00, 3.4028235e+38, 4.1887903e-01,
            3.4028235e+38
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        player_act_space.shape = (1, )
        print(player_act_space)
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 1),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_cartpole
        observer_reward_shaping = observer_reward_shaping_cartpole
        max_decor_steps = 20
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[24],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "hiv":
        work_env = wn.gym.make
        env_name = 'HIV-v0'
        cont_act = False
        state_space_low = np.asarray(
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, np.inf, np.inf, np.inf, np.inf,
            np.inf, np.inf
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 3),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hiv
        observer_reward_shaping = observer_reward_shaping_hiv
        max_decor_steps = 10
        b_size = 32
        num_envs = 8
        max_episode_length = 100
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "heparin":
        work_env = HeparinEnv
        env_name = 'Heparin-Simulator'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 18728.926, 72.84662, 0.0, 0.0,
            0.0, 0.0, 0.0
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.7251439e+04, 1.0664291e+02,
            200.0, 8.9383472e+02, 1.0025734e+02, 1.5770737e+01, 4.7767456e+01
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hep
        observer_reward_shaping = observer_reward_shaping_hep
        max_decor_steps = 3
        b_size = 20
        num_envs = 8
        max_episode_length = 20
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[128],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "halfcheetah":
        assert not serial
        assert not one_agent
        work_env = gym.make
        env_name = 'HalfCheetah-v2'
        cont_act = True
        temp_env = work_env(env_name)
        state_space_low = np.concatenate([
            np.zeros(temp_env.observation_space.low.shape),
            temp_env.observation_space.low
        ])
        state_space_high = np.concatenate([
            np.ones(temp_env.observation_space.high.shape),
            temp_env.observation_space.high
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = temp_env.action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = None
        observer_reward_shaping = None
        temp_env.close()
        max_decor_steps = 0
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_model_kwargs = dict(hidden_sizes=[256, 256])
        player_q_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_q_model_kwargs = dict(hidden_sizes=[256, 256])
        player_v_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_v_model_kwargs = dict(hidden_sizes=[256, 256])
    else:
        raise ValueError(
            f"Unknown game {game!r}; expected 'cartpole', 'hiv', 'heparin' "
            "or 'halfcheetah'.")

    # Observer action space: one mask bit per original state feature (the
    # state vector is [mask half | feature half], hence the division by 2).
    if game == "halfcheetah":
        observer_act_space = Box(
            low=state_space_low[:int(len(state_space_low) / 2)],
            high=state_space_high[:int(len(state_space_high) / 2)])
    else:
        if serial:
            n_serial = int(len(state_space_high) / 2)
            observer_act_space = Discrete(2)
            observer_act_space.shape = (1, )
        else:
            if one_agent:
                observer_act_space = IntBox(
                    low=0,
                    high=player_act_space.n *
                    int(2**int(len(state_space_high) / 2)))
            else:
                observer_act_space = IntBox(low=0,
                                            high=int(2**int(
                                                len(state_space_high) / 2)))

    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        alt = False
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        if eval:
            eval_collector_cl = SerialEvalCollector
        else:
            eval_collector_cl = None
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        alt = False
        Sampler = CpuSampler
        if eval:
            eval_collector_cl = CpuEvalCollector
        else:
            eval_collector_cl = None
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    env_kwargs = dict(work_env=work_env,
                      env_name=env_name,
                      obs_spaces=[obs_space, observer_obs_space],
                      action_spaces=[player_act_space, observer_act_space],
                      serial=serial,
                      player_reward_shaping=player_reward_shaping,
                      observer_reward_shaping=observer_reward_shaping,
                      fully_obs=fully_obs,
                      rand_obs=rand_obs,
                      inc_player_last_act=inc_player_last_act,
                      max_episode_length=max_episode_length,
                      cont_act=cont_act)
    if eval:
        eval_env_kwargs = env_kwargs
        eval_max_steps = 1e4
        num_eval_envs = num_envs
    else:
        eval_env_kwargs = None
        eval_max_steps = None
        num_eval_envs = 0
    sampler = Sampler(
        EnvCls=CWTO_EnvWrapper,
        env_kwargs=env_kwargs,
        batch_T=b_size,
        batch_B=num_envs,
        max_decorrelation_steps=max_decor_steps,
        eval_n_envs=num_eval_envs,
        eval_CollectorCls=eval_collector_cl,
        eval_env_kwargs=eval_env_kwargs,
        eval_max_steps=eval_max_steps,
    )
    if game == "halfcheetah":
        player_algo = SAC()
        observer_algo = SACBeta()
        player = SacAgent(ModelCls=PiMlpModel,
                          QModelCls=QofMuMlpModel,
                          model_kwargs=player_model_kwargs,
                          q_model_kwargs=player_q_model_kwargs,
                          v_model_kwargs=player_v_model_kwargs)
        observer = SacAgentBeta(ModelCls=PiMlpModelBeta,
                                QModelCls=QofMuMlpModel,
                                model_kwargs=observer_model_kwargs,
                                q_model_kwargs=observer_q_model_kwargs,
                                v_model_kwargs=observer_v_model_kwargs)
    else:
        player_model = CWTO_LstmModel
        observer_model = CWTO_LstmModel

        player_algo = PPO()
        observer_algo = PPO()
        player = CWTO_LstmAgent(ModelCls=player_model,
                                model_kwargs=player_model_kwargs,
                                initial_model_state_dict=None)
        observer = CWTO_LstmAgent(ModelCls=observer_model,
                                  model_kwargs=observer_model_kwargs,
                                  initial_model_state_dict=None)
    if one_agent:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask,
                                  one_agent=one_agent,
                                  nplayeract=player_act_space.n)
    else:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask)

    if eval:
        RunnerCl = MinibatchRlEval
    else:
        RunnerCl = MinibatchRl

    runner = RunnerCl(player_algo=player_algo,
                      observer_algo=observer_algo,
                      agent=agent,
                      sampler=sampler,
                      n_steps=n_steps,
                      log_interval_steps=log_interval_steps,
                      affinity=affinity,
                      wandb_log=wandb_log,
                      alt_train=alt_train)
    config = dict(domain=game)
    if game == "halfcheetah":
        name = "sac_" + game
    else:
        name = "ppo_" + game
    log_dir = os.getcwd() + "/cwto_logs/" + name
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
    if save_models_to_wandb:
        agent.save_models_to_wandb()
    if eval_perf:
        eval_n_envs = 10
        eval_envs = [CWTO_EnvWrapper(**env_kwargs) for _ in range(eval_n_envs)]
        set_envs_seeds(eval_envs, make_seed())
        eval_collector = SerialEvalCollector(envs=eval_envs,
                                             agent=agent,
                                             TrajInfoCls=TrajInfo_obs,
                                             max_T=1000,
                                             max_trajectories=10,
                                             log_full_obs=True)
        traj_infos_player, traj_infos_observer = eval_collector.collect_evaluation(
            runner.get_n_itr())
        observations = []
        player_actions = []
        returns = []
        observer_actions = []
        for traj in traj_infos_player:
            observations.append(np.array(traj.Observations))
            player_actions.append(np.array(traj.Actions))
            returns.append(traj.Return)
        for traj in traj_infos_observer:
            observer_actions.append(
                np.array([
                    obs_action_translator(act, eval_envs[0].power_vec,
                                          eval_envs[0].obs_size)
                    for act in traj.Actions
                ]))

        # save results (each file is closed even if pickling fails):
        _dump_pickle('eval_observations.pkl', observations)
        _dump_pickle('eval_returns.pkl', returns)
        _dump_pickle('eval_player_actions.pkl', player_actions)
        _dump_pickle('eval_observer_actions.pkl', observer_actions)