Example #1
def sampling_process(common_kwargs, worker_kwargs):
    """Target function used for forking parallel worker processes in the
    samplers. After ``initialize_worker()``, it creates the specified number
    of environment instances and gives them to the collector when
    instantiating it.  It then calls collector startup methods for
    environments and agent.  If applicable, instantiates evaluation
    environment instances and evaluation collector.

    Then enters infinite loop, waiting for signals from master to collect
    training samples or else run evaluation, until signaled to exit.
    """
    c, w = AttrDict(**common_kwargs), AttrDict(**worker_kwargs)
    initialize_worker(w.rank, w.seed, w.cpus, c.torch_threads)
    envs = [c.EnvCls(**c.env_kwargs) for _ in range(w.n_envs)]
    set_envs_seeds(envs, w.seed)

    collector = c.CollectorCls(
        rank=w.rank,
        envs=envs,
        samples_np=w.samples_np,
        batch_T=c.batch_T,
        TrajInfoCls=c.TrajInfoCls,
        agent=c.get("agent", None),  # Optional depending on parallel setup.
        sync=w.get("sync", None),
        step_buffer_np=w.get("step_buffer_np", None),
        global_B=c.get("global_B", 1),
        env_ranks=w.get("env_ranks", None),
    )
    agent_inputs, traj_infos = collector.start_envs(c.max_decorrelation_steps)
    collector.start_agent()

    if c.get("eval_n_envs", 0) > 0:
        eval_envs = [
            c.EnvCls(**c.eval_env_kwargs) for _ in range(c.eval_n_envs)
        ]
        set_envs_seeds(eval_envs, w.seed)
        eval_collector = c.eval_CollectorCls(
            rank=w.rank,
            envs=eval_envs,
            TrajInfoCls=c.TrajInfoCls,
            traj_infos_queue=c.eval_traj_infos_queue,
            max_T=c.eval_max_T,
            agent=c.get("agent", None),
            sync=w.get("sync", None),
            step_buffer_np=w.get("eval_step_buffer_np", None),
        )
    else:
        eval_envs = list()

    ctrl = c.ctrl
    ctrl.barrier_out.wait()
    while True:
        collector.reset_if_needed(agent_inputs)  # Outside barrier?
        ctrl.barrier_in.wait()
        if ctrl.quit.value:
            break
        if ctrl.do_eval.value:
            # Traj_infos to queue inside.
            eval_collector.collect_evaluation(ctrl.itr.value)
        else:
            agent_inputs, traj_infos, completed_infos = collector.collect_batch(
                agent_inputs, traj_infos, ctrl.itr.value)
            for info in completed_infos:
                c.traj_infos_queue.put(info)
        ctrl.barrier_out.wait()

    for env in envs + eval_envs:
        env.close()
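
A minimal sketch of how a master process could fork workers that run ``sampling_process`` follows. The ``launch_workers`` helper and the way the kwargs are assembled here are illustrative assumptions, not the actual rlpyt sampler internals (which also set up shared sample buffers, control barriers, per-worker seeds, and CPU affinity).

# Illustrative sketch (assumption): forking worker processes that run
# sampling_process.  Real parallel samplers pass much richer kwargs
# (shared buffers, control barriers, per-worker seeds and CPU affinity).
import multiprocessing as mp

def launch_workers(common_kwargs, workers_kwargs):
    procs = [
        mp.Process(target=sampling_process,
                   kwargs=dict(common_kwargs=common_kwargs,
                               worker_kwargs=w_kwargs))
        for w_kwargs in workers_kwargs
    ]
    for p in procs:
        p.start()
    return procs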
Example #2
def sampling_process(common_kwargs, worker_kwargs):
    """Target function used for forking parallel worker processes in the
    samplers. After ``initialize_worker()``, it creates the specified number
    of environment instances and gives them to the collector when
    instantiating it.  It then calls collector startup methods for
    environments and agent.  If applicable, instantiates evaluation
    environment instances and evaluation collector.

    Then enters infinite loop, waiting for signals from master to collect
    training samples or else run evaluation, until signaled to exit.
    """
    c, w = AttrDict(**common_kwargs), AttrDict(**worker_kwargs)
    initialize_worker(w.rank, w.seed, w.cpus, c.torch_threads)

    envs = [c.EnvCls(**c.env_kwargs) for _ in range(w.n_envs)]

    log_heatmaps = c.env_kwargs.get('log_heatmaps', False)

    if log_heatmaps:
        # Only the first environment keeps heatmap logging enabled.
        for env in envs[1:]:
            env.log_heatmaps = False

    if c.record_freq > 0:
        if c.env_kwargs['game'] in ATARI_ENVS:
            envs[0].record_env = True
            os.makedirs(os.path.join(c.log_dir, 'videos/frames'))
        elif c.get("eval_n_envs", 0) == 0:
            # Only record in workers if no evaluation processes are run.
            envs[0] = Monitor(
                envs[0],
                c.log_dir + '/videos',
                video_callable=lambda episode_id: episode_id % c.record_freq == 0)

    set_envs_seeds(envs, w.seed)

    collector = c.CollectorCls(
        rank=w.rank,
        envs=envs,
        samples_np=w.samples_np,
        batch_T=c.batch_T,
        TrajInfoCls=c.TrajInfoCls,
        agent=c.get("agent", None),  # Optional depending on parallel setup.
        sync=w.get("sync", None),
        step_buffer_np=w.get("step_buffer_np", None),
        global_B=c.get("global_B", 1),
        env_ranks=w.get("env_ranks", None),
        no_extrinsic=c.no_extrinsic)
    agent_inputs, traj_infos = collector.start_envs(c.max_decorrelation_steps)
    collector.start_agent()

    if c.get("eval_n_envs", 0) > 0:
        eval_envs = [
            c.EnvCls(**c.eval_env_kwargs) for _ in range(c.eval_n_envs)
        ]
        if c.record_freq > 0:
            eval_envs[0] = Monitor(
                eval_envs[0],
                c.log_dir + '/videos',
                video_callable=lambda episode_id: episode_id % c.record_freq == 0)
        set_envs_seeds(eval_envs, w.seed)
        eval_collector = c.eval_CollectorCls(
            rank=w.rank,
            envs=eval_envs,
            TrajInfoCls=c.TrajInfoCls,
            traj_infos_queue=c.eval_traj_infos_queue,
            max_T=c.eval_max_T,
            agent=c.get("agent", None),
            sync=w.get("sync", None),
            step_buffer_np=w.get("eval_step_buffer_np", None),
        )
    else:
        eval_envs = list()

    ctrl = c.ctrl
    ctrl.barrier_out.wait()
    while True:
        collector.reset_if_needed(agent_inputs)  # Outside barrier?
        ctrl.barrier_in.wait()
        if ctrl.quit.value:
            logger.log('Quitting worker ...')
            break
        if ctrl.do_eval.value:
            eval_collector.collect_evaluation(
                ctrl.itr.value)  # Traj_infos to queue inside.
        else:
            agent_inputs, traj_infos, completed_infos = collector.collect_batch(
                agent_inputs, traj_infos, ctrl.itr.value)
            for info in completed_infos:
                c.traj_infos_queue.put(info)
        ctrl.barrier_out.wait()

    for env in envs + eval_envs:
        logger.log('Stopping env ...')
        env.close()
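
The worker loop above pairs with a master-side control structure. The sketch below shows one way such a ``ctrl`` object could be built and driven; the field names (``quit``, ``barrier_in``, ``barrier_out``, ``do_eval``, ``itr``) follow the worker code, but the ``Ctrl`` namedtuple and the ``make_ctrl``/``master_step`` helpers are assumptions for illustration only.

# Illustrative sketch (assumption) of the master-side control structure that
# the worker loop synchronizes on; the real samplers build this internally.
import ctypes
import multiprocessing as mp
from collections import namedtuple

Ctrl = namedtuple("Ctrl", ["quit", "barrier_in", "barrier_out", "do_eval", "itr"])

def make_ctrl(n_workers):
    return Ctrl(
        quit=mp.RawValue(ctypes.c_bool, False),     # Workers break their loop when set.
        barrier_in=mp.Barrier(n_workers + 1),       # Master + workers enter an iteration.
        barrier_out=mp.Barrier(n_workers + 1),      # Master + workers leave an iteration.
        do_eval=mp.RawValue(ctypes.c_bool, False),  # Collect evaluation instead of a batch.
        itr=mp.RawValue(ctypes.c_long, 0),          # Current training iteration.
    )

def master_step(ctrl, itr, do_eval=False):
    # One master-side iteration: publish the iteration number and mode,
    # release the workers at barrier_in, then wait for them at barrier_out.
    ctrl.itr.value = itr
    ctrl.do_eval.value = do_eval
    ctrl.barrier_in.wait()
    ctrl.barrier_out.wait()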
Example #3
    def initialize(
            self,
            agent,
            affinity=None,
            seed=None,
            bootstrap_value=False,
            traj_info_kwargs=None,
            rank=0,
            world_size=1,
            ):
        """Store the input arguments.  Instantiate the specified number of environment
        instances (``batch_B``).  Initialize the agent, and pre-allocate a memory buffer
        to hold the samples collected in each batch.  Applies ``traj_info_kwargs`` settings
        to the `TrajInfoCls` by direct class attribute assignment.  Instantiates the Collector
        and, if applicable, the evaluation Collector.

        Returns a structure of individual examples for data fields such as `observation`,
        `action`, etc., which can be used to allocate a replay buffer.
        """
        B = self.batch_spec.B
        envs = [self.EnvCls(**self.env_kwargs) for _ in range(B)]

        set_envs_seeds(envs, seed)  # Random seed made in runner.

        global_B = B * world_size
        env_ranks = list(range(rank * B, (rank + 1) * B))
        agent.observer.initialize(envs[0].spaces()[1], share_memory=False,
            global_B=global_B, env_ranks=env_ranks)
        agent.player.initialize(envs[0].spaces()[0], share_memory=False,
            global_B=global_B, env_ranks=env_ranks)

        observer_samples_pyt, observer_samples_np, observer_examples = build_samples_buffer(agent.observer, envs[0],
            self.batch_spec, bootstrap_value, agent_shared=False,
            env_shared=False, subprocess=False)

        player_samples_pyt, player_samples_np, player_examples = build_samples_buffer(agent.player, envs[0],
            self.batch_spec, bootstrap_value, agent_shared=False,
            env_shared=False, subprocess=False)
        if traj_info_kwargs:
            for k, v in traj_info_kwargs.items():
                setattr(self.TrajInfoCls, "_" + k, v)  # Avoid passing at init.
        collector = self.CollectorCls(
            rank=0,
            envs=envs,
            player_samples_np=player_samples_np,
            observer_samples_np=observer_samples_np,
            batch_T=self.batch_spec.T,
            TrajInfoCls=self.TrajInfoCls,
            agent=agent,
            global_B=global_B,
            env_ranks=env_ranks,  # Might get applied redundantly to agent.
        )
        if self.eval_n_envs > 0:  # May do evaluation.
            eval_envs = [self.EnvCls(**self.eval_env_kwargs)
                for _ in range(self.eval_n_envs)]
            set_envs_seeds(eval_envs, seed)
            eval_CollectorCls = self.eval_CollectorCls or SerialEvalCollector
            self.eval_collector = eval_CollectorCls(
                envs=eval_envs,
                agent=agent,
                TrajInfoCls=self.TrajInfoCls,
                max_T=self.eval_max_steps // self.eval_n_envs,
                max_trajectories=self.eval_max_trajectories,
            )

        player_agent_inputs, player_traj_infos, observer_agent_inputs, observer_traj_infos = collector.start_envs(
            self.max_decorrelation_steps)
        collector.start_agent()

        self.agent = agent
        self.player_samples_pyt = player_samples_pyt
        self.player_samples_np = player_samples_np
        self.observer_samples_pyt = observer_samples_pyt
        self.observer_samples_np = observer_samples_np
        self.collector = collector
        self.player_agent_inputs = player_agent_inputs
        self.player_traj_infos = player_traj_infos
        self.observer_agent_inputs = observer_agent_inputs
        self.observer_traj_infos = observer_traj_infos
        logger.log("Serial Sampler initialized.")
        return player_examples, observer_examples
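
A short usage sketch for this two-agent serial sampler follows. Only the ``initialize()`` call pattern and its two return values are shown; the surrounding runner, agent construction, and seeding are elided, and the ``setup_two_agent_sampler`` helper name is an assumption.

# Illustrative sketch (assumption): calling initialize() from a runner-like
# context.  Each returned structure holds one example instance of every data
# field (observation, action, reward, ...), e.g. for sizing a replay buffer.
def setup_two_agent_sampler(sampler, agent, seed):
    player_examples, observer_examples = sampler.initialize(
        agent=agent,
        affinity=None,
        seed=seed,
        bootstrap_value=False,
        traj_info_kwargs=None,
        rank=0,
        world_size=1,
    )
    return player_examples, observer_examples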
Example #4
def build_and_train(game="cartpole",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2,
                    eval=False,
                    serial=False,
                    train_mask=[True, True],
                    wandb_log=False,
                    save_models_to_wandb=False,
                    log_interval_steps=1e5,
                    observation_mode="agent",
                    inc_player_last_act=False,
                    alt_train=False,
                    eval_perf=False,
                    n_steps=50e6,
                    one_agent=False):
    # Define the environment configuration for each game:
    if observation_mode == "agent":
        fully_obs = False
        rand_obs = False
    elif observation_mode == "random":
        fully_obs = False
        rand_obs = True
    elif observation_mode == "full":
        fully_obs = True
        rand_obs = False

    n_serial = None
    if game == "cartpole":
        work_env = gym.make
        env_name = 'CartPole-v1'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, -4.8000002e+00, -3.4028235e+38, -4.1887903e-01,
            -3.4028235e+38
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 4.8000002e+00, 3.4028235e+38, 4.1887903e-01,
            3.4028235e+38
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        player_act_space.shape = (1, )
        print(player_act_space)
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 1),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_cartpole
        observer_reward_shaping = observer_reward_shaping_cartpole
        max_decor_steps = 20
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[24],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "hiv":
        work_env = wn.gym.make
        env_name = 'HIV-v0'
        cont_act = False
        state_space_low = np.asarray(
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, np.inf, np.inf, np.inf, np.inf,
            np.inf, np.inf
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 3),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hiv
        observer_reward_shaping = observer_reward_shaping_hiv
        max_decor_steps = 10
        b_size = 32
        num_envs = 8
        max_episode_length = 100
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "heparin":
        work_env = HeparinEnv
        env_name = 'Heparin-Simulator'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 18728.926, 72.84662, 0.0, 0.0,
            0.0, 0.0, 0.0
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.7251439e+04, 1.0664291e+02,
            200.0, 8.9383472e+02, 1.0025734e+02, 1.5770737e+01, 4.7767456e+01
        ])
        # state_space_low = np.asarray([0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,18728.926,72.84662,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
        # state_space_high = np.asarray([1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,2.7251439e+04,1.0664291e+02,0.0000000e+00,8.9383472e+02,1.4476662e+02,1.3368750e+02,1.6815166e+02,1.0025734e+02,1.5770737e+01,4.7767456e+01,7.7194958e+00])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hep
        observer_reward_shaping = observer_reward_shaping_hep
        max_decor_steps = 3
        b_size = 20
        num_envs = 8
        max_episode_length = 20
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[128],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "halfcheetah":
        assert not serial
        assert not one_agent
        work_env = gym.make
        env_name = 'HalfCheetah-v2'
        cont_act = True
        temp_env = work_env(env_name)
        state_space_low = np.concatenate([
            np.zeros(temp_env.observation_space.low.shape),
            temp_env.observation_space.low
        ])
        state_space_high = np.concatenate([
            np.ones(temp_env.observation_space.high.shape),
            temp_env.observation_space.high
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = temp_env.action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = None
        observer_reward_shaping = None
        temp_env.close()
        max_decor_steps = 0
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_model_kwargs = dict(hidden_sizes=[256, 256])
        player_q_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_q_model_kwargs = dict(hidden_sizes=[256, 256])
        player_v_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_v_model_kwargs = dict(hidden_sizes=[256, 256])
    if game == "halfcheetah":
        observer_act_space = Box(
            low=state_space_low[:int(len(state_space_low) / 2)],
            high=state_space_high[:int(len(state_space_high) / 2)])
    else:
        if serial:
            n_serial = int(len(state_space_high) / 2)
            observer_act_space = Discrete(2)
            observer_act_space.shape = (1, )
        else:
            if one_agent:
                observer_act_space = IntBox(
                    low=0,
                    high=player_act_space.n *
                    int(2**int(len(state_space_high) / 2)))
            else:
                observer_act_space = IntBox(low=0,
                                            high=int(2**int(
                                                len(state_space_high) / 2)))

    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        alt = False
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        if eval:
            eval_collector_cl = SerialEvalCollector
        else:
            eval_collector_cl = None
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        alt = False
        Sampler = CpuSampler
        if eval:
            eval_collector_cl = CpuEvalCollector
        else:
            eval_collector_cl = None
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    env_kwargs = dict(work_env=work_env,
                      env_name=env_name,
                      obs_spaces=[obs_space, observer_obs_space],
                      action_spaces=[player_act_space, observer_act_space],
                      serial=serial,
                      player_reward_shaping=player_reward_shaping,
                      observer_reward_shaping=observer_reward_shaping,
                      fully_obs=fully_obs,
                      rand_obs=rand_obs,
                      inc_player_last_act=inc_player_last_act,
                      max_episode_length=max_episode_length,
                      cont_act=cont_act)
    if eval:
        eval_env_kwargs = env_kwargs
        eval_max_steps = 1e4
        num_eval_envs = num_envs
    else:
        eval_env_kwargs = None
        eval_max_steps = None
        num_eval_envs = 0
    sampler = Sampler(
        EnvCls=CWTO_EnvWrapper,
        env_kwargs=env_kwargs,
        batch_T=b_size,
        batch_B=num_envs,
        max_decorrelation_steps=max_decor_steps,
        eval_n_envs=num_eval_envs,
        eval_CollectorCls=eval_collector_cl,
        eval_env_kwargs=eval_env_kwargs,
        eval_max_steps=eval_max_steps,
    )
    if game == "halfcheetah":
        player_algo = SAC()
        observer_algo = SACBeta()
        player = SacAgent(ModelCls=PiMlpModel,
                          QModelCls=QofMuMlpModel,
                          model_kwargs=player_model_kwargs,
                          q_model_kwargs=player_q_model_kwargs,
                          v_model_kwargs=player_v_model_kwargs)
        observer = SacAgentBeta(ModelCls=PiMlpModelBeta,
                                QModelCls=QofMuMlpModel,
                                model_kwargs=observer_model_kwargs,
                                q_model_kwargs=observer_q_model_kwargs,
                                v_model_kwargs=observer_v_model_kwargs)
    else:
        player_model = CWTO_LstmModel
        observer_model = CWTO_LstmModel

        player_algo = PPO()
        observer_algo = PPO()
        player = CWTO_LstmAgent(ModelCls=player_model,
                                model_kwargs=player_model_kwargs,
                                initial_model_state_dict=None)
        observer = CWTO_LstmAgent(ModelCls=observer_model,
                                  model_kwargs=observer_model_kwargs,
                                  initial_model_state_dict=None)
    if one_agent:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask,
                                  one_agent=one_agent,
                                  nplayeract=player_act_space.n)
    else:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask)

    if eval:
        RunnerCl = MinibatchRlEval
    else:
        RunnerCl = MinibatchRl

    runner = RunnerCl(player_algo=player_algo,
                      observer_algo=observer_algo,
                      agent=agent,
                      sampler=sampler,
                      n_steps=n_steps,
                      log_interval_steps=log_interval_steps,
                      affinity=affinity,
                      wandb_log=wandb_log,
                      alt_train=alt_train)
    config = dict(domain=game)
    if game == "halfcheetah":
        name = "sac_" + game
    else:
        name = "ppo_" + game
    log_dir = os.getcwd() + "/cwto_logs/" + name
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
    if save_models_to_wandb:
        agent.save_models_to_wandb()
    if eval_perf:
        eval_n_envs = 10
        eval_envs = [CWTO_EnvWrapper(**env_kwargs) for _ in range(eval_n_envs)]
        set_envs_seeds(eval_envs, make_seed())
        eval_collector = SerialEvalCollector(envs=eval_envs,
                                             agent=agent,
                                             TrajInfoCls=TrajInfo_obs,
                                             max_T=1000,
                                             max_trajectories=10,
                                             log_full_obs=True)
        traj_infos_player, traj_infos_observer = eval_collector.collect_evaluation(
            runner.get_n_itr())
        observations = []
        player_actions = []
        returns = []
        observer_actions = []
        for traj in traj_infos_player:
            observations.append(np.array(traj.Observations))
            player_actions.append(np.array(traj.Actions))
            returns.append(traj.Return)
        for traj in traj_infos_observer:
            observer_actions.append(
                np.array([
                    obs_action_translator(act, eval_envs[0].power_vec,
                                          eval_envs[0].obs_size)
                    for act in traj.Actions
                ]))

        # Save results:
        with open('eval_observations.pkl', "wb") as open_obs:
            pickle.dump(observations, open_obs)
        with open('eval_returns.pkl', "wb") as open_ret:
            pickle.dump(returns, open_ret)
        with open('eval_player_actions.pkl', "wb") as open_pact:
            pickle.dump(player_actions, open_pact)
        with open('eval_observer_actions.pkl', "wb") as open_oact:
            pickle.dump(observer_actions, open_oact)
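
A minimal command-line entry point that one might pair with ``build_and_train`` is sketched below. The flag names simply mirror a few of the function's parameters, and the argument set is deliberately incomplete; it is an illustrative assumption, not part of the original script.

# Illustrative sketch (assumption): argparse entry point for build_and_train.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--game", default="cartpole",
                        choices=["cartpole", "hiv", "heparin", "halfcheetah"])
    parser.add_argument("--run_ID", type=int, default=0)
    parser.add_argument("--cuda_idx", type=int, default=None)
    parser.add_argument("--sample_mode", default="serial", choices=["serial", "cpu"])
    parser.add_argument("--n_parallel", type=int, default=2)
    parser.add_argument("--eval", action="store_true")
    args = parser.parse_args()
    build_and_train(game=args.game,
                    run_ID=args.run_ID,
                    cuda_idx=args.cuda_idx,
                    sample_mode=args.sample_mode,
                    n_parallel=args.n_parallel,
                    eval=args.eval)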