Example #1
def test_ppo():
    from toolbox.utils import initialize_ray
    initialize_ray(test_mode=True)
    po = PPOAgentWithActivation(
        env="BipedalWalker-v2", config=fc_with_activation_model_config
    )
    return po
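
All of the examples in this listing call initialize_ray before using Ray. The helper itself (from toolbox.utils) is not shown here; judging from the call sites, it is a thin wrapper around ray.init that forwards resource settings and turns a test_mode flag into quieter defaults for quick local runs. A minimal sketch under that assumption, not the actual toolbox implementation (PPOAgentWithActivation and fc_with_activation_model_config above are likewise defined elsewhere in the source module):

import logging

import ray

def initialize_ray(test_mode=False, local_mode=False, num_gpus=None,
                   object_store_memory=None, **kwargs):
    # Hypothetical reconstruction of toolbox.utils.initialize_ray:
    # test_mode plausibly just raises log verbosity for local debugging.
    if not ray.is_initialized():
        ray.init(local_mode=local_mode,
                 num_gpus=num_gpus,
                 object_store_memory=object_store_memory,
                 logging_level=logging.DEBUG if test_mode else logging.INFO,
                 **kwargs)
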
Example #2
def train_one_iteration(
        iter_id,
        exp_name,
        init_yaml_path,
        config,
        stop_criterion,
        num_seeds=1,
        num_gpus=0,
        test_mode=False
):
    assert isinstance(iter_id, int)
    assert isinstance(exp_name, str)
    assert isinstance(stop_criterion, dict)
    assert isinstance(init_yaml_path, str)
    assert osp.exists(init_yaml_path)

    local_dir = get_local_dir() or "~/ray_results"
    local_dir = os.path.expanduser(local_dir)
    save_path = os.path.join(local_dir, exp_name)
    current_yaml_path = init_yaml_path

    assert 'seed' not in exp_name, exp_name
    assert 'iter' not in exp_name, exp_name

    for i in range(num_seeds):
        input_exp_name = exp_name + "_seed{}_iter{}".format(i, iter_id)

        tmp_config = copy.deepcopy(config)
        tmp_config.update(seed=i)
        tmp_config['env_config']['yaml_path'] = current_yaml_path
        initialize_ray(
            num_gpus=num_gpus, test_mode=test_mode, local_mode=test_mode
        )
        tune.run(
            "PPO",
            name=input_exp_name,
            verbose=2 if test_mode else 1,
            local_dir=save_path,
            checkpoint_freq=10,
            checkpoint_at_end=True,
            stop=stop_criterion,
            config=tmp_config
        )

        name_ckpt_mapping = read_yaml(current_yaml_path)
        ckpt_path = _search_ckpt(save_path, input_exp_name)

        last_ckpt_dict = copy.deepcopy(list(name_ckpt_mapping.values())[-1])
        assert isinstance(last_ckpt_dict, dict), last_ckpt_dict
        assert 'path' in last_ckpt_dict, last_ckpt_dict
        last_ckpt_dict.update(path=ckpt_path)

        print("Finish the current last_ckpt_dict: ", last_ckpt_dict)
        name_ckpt_mapping[input_exp_name] = last_ckpt_dict

        current_yaml_path = osp.join(save_path, "post_agent_ppo.yaml")
        out = save_yaml(name_ckpt_mapping, current_yaml_path)
        assert out == current_yaml_path
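
A hypothetical invocation of train_one_iteration, just to make the expected argument shapes concrete; every path and value here is a placeholder, not something from the source:

# Placeholders only: the yaml must exist on disk, the experiment name must
# not contain 'seed' or 'iter', and config must carry an 'env_config' dict.
train_one_iteration(
    iter_id=0,
    exp_name="demo",
    init_yaml_path="data/init_agents.yaml",
    config={"env": "BipedalWalker-v2", "env_config": {}},
    stop_criterion={"training_iteration": 100},
    num_seeds=2,
)
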
Example #3
def __init__(self,
             video_path,
             local_mode=False,
             fps=50,
             require_full_frame=False):
    initialize_ray(local_mode)
    self.video_path = video_path
    self.fps = fps
    self.require_full_frame = require_full_frame
Example #4
def test_generate_trailer_from_agent():
    initialize_ray(test_mode=True)
    agent = get_ppo_agent("BipedalWalker-v2")
    ret = generate_trailer_from_agent(
        agent, "test_agent", "/tmp/single_video", _steps=None
    )
    print(ret)
    # agent = get_ppo_agent_with_mask("BipedalWalker-v2")
    # ret = generate_gif_from_agent(agent, "test_agent", "/tmp/test_genrate_gif_with_mask", _steps=50)
    # print(ret)
    return ret
Example #5
def test_generate_gif_from_agent():
    initialize_ray(test_mode=True)
    agent = get_ppo_agent("BipedalWalker-v2")
    ret = generate_gif_from_agent(
        agent, "test_agent", "/tmp/test_genrate", _steps=50
    )
    print(ret)
    agent = get_ppo_agent_with_mask("BipedalWalker-v2")
    ret = generate_gif_from_agent(
        agent, "test_agent", "/tmp/test_genrate_gif_with_mask", _steps=50
    )
    print(ret)
    return ret
Example #6
def test_efficient_rollout():
    initialize_ray(num_gpus=4, test_mode=False)
    num_rollouts = 2
    num_workers = 2
    # yaml_path = "data/0902-ppo-20-agents/0902-ppo-20-agents.yaml"
    yaml_path = "data/0811-random-test.yaml"
    ret = get_fft_cluster_finder(
        yaml_path=yaml_path,
        num_rollouts=num_rollouts,
        num_workers=num_workers,
        num_gpus=4
    )
    return ret
Example #7
def test_marl_individual_ppo(extra_config, local_mode=True, test_mode=True):
    num_gpus = 0
    exp_name = "test_marl_individual_ppo"
    env_name = "BipedalWalker-v2"
    num_iters = 50
    num_agents = 8

    initialize_ray(
        test_mode=test_mode, num_gpus=num_gpus, local_mode=local_mode
    )

    tmp_env = get_env_maker(env_name)()

    default_policy = (
        None, tmp_env.observation_space, tmp_env.action_space, {}
    )

    policy_names = ["ppo_agent{}".format(i) for i in range(num_agents)]

    def policy_mapping_fn(aid):
        # print("input aid: ", aid)
        return aid

    config = {
        "env": MultiAgentEnvWrapper,
        "env_config": {
            "env_name": env_name,
            "agent_ids": policy_names
        },
        "log_level": "DEBUG",
        "num_gpus": num_gpus,
        "multiagent": {
            "policies": {i: default_policy
                         for i in policy_names},
            "policy_mapping_fn": policy_mapping_fn,
        },
    }

    if isinstance(extra_config, dict):
        config.update(extra_config)

    tune.run(
        "PPO",
        local_dir=get_local_dir(),
        name=exp_name,
        checkpoint_at_end=True,
        checkpoint_freq=10,
        stop={"training_iteration": num_iters},
        config=config,
    )
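
In RLlib's multi-agent API, policy_mapping_fn receives an agent id and must return the id of the policy that controls that agent. The identity mapping above works only because MultiAgentEnvWrapper is configured with agent_ids equal to the policy names:

# Agent ids and policy names coincide ("ppo_agent0" ... "ppo_agent7"),
# so mapping an agent id to itself selects the matching policy.
assert policy_mapping_fn("ppo_agent3") == "ppo_agent3"
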
Example #8
def test_generate_gif_from_agent_mujoco_environment():
    initialize_ray(test_mode=True)
    agent = get_ppo_agent("HalfCheetah-v2")
    output_path = "(delete-me!)test_genrate_gif_mujoco"
    agent_name = "test_agent_mujoco"

    gvr = GridVideoRecorder(
        video_path=output_path, fps=FPS, require_full_frame=True
    )
    frames_dict, extra_info_dict = gvr.generate_frames_from_agent(
        agent, agent_name
    )

    name_path_dict = gvr.generate_gif(frames_dict, extra_info_dict)
    print("Gif has been saved at: ", name_path_dict)
Example #9
def test():
    from collections import OrderedDict
    from toolbox.evaluate import MaskSymbolicAgent
    from toolbox.utils import initialize_ray, get_random_string
    import shutil

    initialize_ray(test_mode=True)

    print("Finish init")

    num_workers = 4
    num_agents = 8
    base_output_path = "/tmp/generate_trailer"
    ckpt = {
        "path":
        "~/ray_results/0810-20seeds/PPO_BipedalWalker-v2_"
        "0_seed=20_2019-08-10_16-54-37xaa2muqm/"
        "checkpoint_469/checkpoint-469",
        "env_name":
        "BipedalWalker-v2",
        "run_name":
        "PPO"
    }

    shutil.rmtree(base_output_path, ignore_errors=True)

    master_agents = OrderedDict()

    for _ in range(num_agents):
        ckpt['name'] = get_random_string()
        agent = MaskSymbolicAgent(ckpt)
        master_agents[ckpt['name']] = agent

    print("Master agents: ", master_agents)

    rsavm = RemoteSymbolicAgentVideoManager(num_workers, len(master_agents))

    for name, symbolic_agent in master_agents.items():
        rsavm.generate_video(name, symbolic_agent, base_output_path)

    result = rsavm.get_result()
    print(result)

    return result
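
get_random_string is only used here to give each cloned checkpoint record a unique name. The real helper is not shown in this listing; a stand-in with the same contract could look like this (an assumption, not the toolbox implementation):

import random
import string

def get_random_string(length=8):
    # Any short unique token works: the value only serves as a dict key.
    return "".join(
        random.choices(string.ascii_lowercase + string.digits, k=length))
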
Example #10
def test_generate_gif_from_restored_agent_mujoco_environment():
    initialize_ray(test_mode=True)
    # agent = get_ppo_agent("HalfCheetah-v2")

    ckpt = "/home/zhpeng/ray_results/0915-hc-ppo-5-agents/" \
           "PPO_HalfCheetah-v2_2_seed=2_2019-09-15_15-01-01hyqn2x2v/" \
           "checkpoint_1060/checkpoint-1060"

    agent = restore_agent("PPO", ckpt, "HalfCheetah-v2")

    output_path = "(delete-me!)test_genrate_gif_mujoco"
    agent_name = "test_agent_mujoco"

    gvr = GridVideoRecorder(
        video_path=output_path, fps=FPS, require_full_frame=True
    )
    frames_dict, extra_info_dict = gvr.generate_frames_from_agent(
        agent, agent_name
    )

    name_path_dict = gvr.generate_gif(frames_dict, extra_info_dict)
    print("Gif has been saved at: ", name_path_dict)
Example #11
def test_heavy_memory_usage():
    initialize_ray(test_mode=True, object_store_memory=400 * MB)
    # initialize_ray(test_mode=True)

    num = 50
    delay = 0
    num_workers = 16

    class TestWorker(WorkerBase):
        def __init__(self):
            self.count = 0

        def run(self):
            time.sleep(delay)
            self.count += 1
            print(self.count, ray.cluster_resources())
            arr = np.empty((10 * MB), dtype=np.uint8)
            print("in worker getsize", getsizeof(arr))
            # arr = 1
            return self.count, arr
            # oid = ray.put((self.count, arr))
            # return oid

    class TestManager(WorkerManagerBase):
        def __init__(self):
            super(TestManager, self).__init__(num_workers, TestWorker, num, 1,
                                              'test')

        def count(self, index):
            self.submit(index)

    tm = TestManager()
    for i in range(num):
        tm.count(i)

    ret = tm.get_result()
    return ret
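
The MB constant is defined outside the snippet; since object_store_memory is ultimately handed to ray.init in bytes, it is presumably the usual byte multiple (and getsizeof comes from sys):

# Presumed definition, not shown in the snippet.
MB = 1024 ** 2

With that reading, the test pushes 50 results of roughly 10 MB each through a 400 MB object store, which is exactly the memory-pressure scenario the test name suggests.
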
Example #12
                     ] = run_specify_stop["off_policy_tnb"]
    run_specify_stop["on_policy_tnb_min_novelty"
                     ] = run_specify_stop["off_policy_tnb"]
    run_specify_stop["on_policy_tnb"] = run_specify_stop["off_policy_tnb"]
    run_specify_stop["tnb_4in1"] = run_specify_stop["off_policy_tnb"]
    run_specify_stop["adaptive_extra_loss"] = run_specify_stop["extra_loss"]
    run_specify_stop["smart_adaptive_extra_loss"
                     ] = run_specify_stop["extra_loss"]

    assert run_name in run_dict, "--run argument should be in {}, " \
                                 "but you provided {}." \
                                 "".format(run_dict.keys(), run_name)

    initialize_ray(
        num_gpus=args.num_gpus,
        test_mode=args.test_mode,
        object_store_memory=40 * 1024 * 1024 * 1024,
        # temp_dir="/data1/pengzh/tmp"
    )

    policy_names = ["ppo_agent{}".format(i) for i in range(args.num_agents)]

    tmp_env = get_env_maker(env_name)()
    default_policy = (
        None, tmp_env.observation_space, tmp_env.action_space, {}
    )

    config = {
        "env": MultiAgentEnvWrapper,
        "env_config": {
            "env_name": env_name,
            "agent_ids": policy_names
Example #13
    "env":
    args.env if args.env != "atari" else tune.grid_search([
        "BreakoutNoFrameskip-v4", "BeamRiderNoFrameskip-v4",
        "QbertNoFrameskip-v4", "SpaceInvadersNoFrameskip-v4"
    ]),
    "num_gpus":
    0.15 if 0.15 < args.num_gpus else 0,
    "num_cpus_for_driver":
    0.2,
    "num_cpus_per_worker":
    0.75
}

run_config = merge_dicts(general_config, algo_specify_config['config'])

initialize_ray(num_gpus=args.num_gpus, test_mode=args.test_mode)

assert args.run == "PPO"
register_mixture_action_distribution()

run_config["model"] = {
    "custom_action_dist": GaussianMixture.name,
    "custom_options": {
        "num_components": tune.grid_search([1, 2, 3, 5, 10])
    }
}

assert run_config['model']['custom_options']

tune.run(
    PPOTrainerWithoutKL,
Example #14
            # for obj_id in obj_ids:
            #     trajectory_list.append(ray.get(obj_id))
            return_dict[name] = trajectory_list
            # worker.close.remote()
            print("[{}/{}] (+{:.1f}s/{:.1f}s) Collected {} rollouts from agent"
                  " <{}>".format(agent_count_get, num_agents,
                                 time.time() - now_t_get,
                                 time.time() - start_t, num_rollouts, name))
            agent_count_get += 1
            now_t_get = time.time()
    return return_dict if return_data else None


if __name__ == "__main__":
    # test_serveral_agent_rollout(True)
    # exit(0)
    # _test_es_agent_compatibility()
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml-path", required=True, type=str)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num-rollouts", '-n', type=int, default=100)
    parser.add_argument("--num-workers", type=int, default=10)
    parser.add_argument("--num-gpus", '-g', type=int, default=4)
    parser.add_argument("--force-rewrite", action="store_true")
    args = parser.parse_args()
    assert args.yaml_path.endswith("yaml")

    initialize_ray(num_gpus=args.num_gpus)
    several_agent_rollout(args.yaml_path, args.num_rollouts, args.seed,
                          args.num_workers, args.force_rewrite)
Example #15
def ir():
    initialize_ray(test_mode=False, num_gpus=2)
Example #16
from ray import tune

from toolbox.env import get_env_maker
from toolbox.marl import MultiAgentEnvWrapper, on_train_result
from toolbox.marl.smart_adaptive_extra_loss import \
    SmartAdaptiveExtraLossPPOTrainer
from toolbox.utils import get_local_dir, initialize_ray

if __name__ == '__main__':
    num_gpus = 4
    num_agents = 10
    env_name = "HumanoidBulletEnv-v0"
    num_seeds = 1
    exp_name = "1130-SAEL-humanoid"

    initialize_ray(num_gpus=num_gpus,
                   object_store_memory=40 * 1024 * 1024 * 1024)
    policy_names = ["ppo_agent{}".format(i) for i in range(num_agents)]
    tmp_env = get_env_maker(env_name)()
    default_policy = (None, tmp_env.observation_space, tmp_env.action_space,
                      {})

    config = {
        # Config specific to this experiment
        # "performance_evaluation_metric": tune.grid_search(['max', 'mean']),
        "use_joint_dataset": tune.grid_search([True, False]),
        "novelty_mode": tune.grid_search(['min', 'mean']),

        # Humanoid-specific parameters
        "gamma": 0.995,
        "lambda": 0.95,
        "clip_param": 0.2,
Example #17
def train(trainer,
          config,
          stop,
          exp_name,
          num_seeds=1,
          num_gpus=0,
          test_mode=False,
          suffix="",
          checkpoint_freq=10,
          keep_checkpoints_num=None,
          start_seed=0,
          **kwargs):
    # initialize ray
    if not os.environ.get("redis_password"):
        initialize_ray(test_mode=test_mode,
                       local_mode=False,
                       num_gpus=num_gpus)
    else:
        password = os.environ.get("redis_password")
        assert os.environ.get("ip_head")
        print("We detect redis_password ({}) exists in environment! So "
              "we will start a ray cluster!".format(password))
        if num_gpus:
            print("We are in cluster mode! So GPU specification is disable and"
                  " should be done when submitting task to cluster! You are "
                  "requiring {} GPU for each machine!".format(num_gpus))
        initialize_ray(address=os.environ["ip_head"],
                       test_mode=test_mode,
                       redis_password=password)

    # prepare config
    used_config = {
        "seed":
        tune.grid_search([i * 100 + start_seed for i in range(num_seeds)]),
        "log_level": "DEBUG" if test_mode else "INFO"
    }
    if config:
        used_config.update(config)
    config = copy.deepcopy(used_config)

    env_name = _get_env_name(config)

    trainer_name = trainer if isinstance(trainer, str) else trainer._name

    assert isinstance(env_name, str) or isinstance(env_name, list)
    if isinstance(env_name, str):
        env_names = [env_name]
    else:
        env_names = env_name
    for e in env_names:
        register_bullet(e)
        register_minigrid(e)

    if not isinstance(stop, dict):
        assert np.isscalar(stop)
        stop = {"timesteps_total": int(stop)}

    if keep_checkpoints_num is not None and not test_mode:
        assert isinstance(keep_checkpoints_num, int)
        kwargs["keep_checkpoints_num"] = keep_checkpoints_num
        kwargs["checkpoint_score_attr"] = "episode_reward_mean"

    if "verbose" not in kwargs:
        kwargs["verbose"] = 1 if not test_mode else 2

    # start training
    analysis = tune.run(
        trainer,
        name=exp_name,
        checkpoint_freq=checkpoint_freq if not test_mode else None,
        checkpoint_at_end=True,
        stop=stop,
        config=config,
        max_failures=20 if not test_mode else 1,
        reuse_actors=False,
        **kwargs)

    # save training progress as insurance
    pkl_path = "{}-{}-{}{}.pkl".format(exp_name, trainer_name, env_name,
                                       "" if not suffix else "-" + suffix)
    with open(pkl_path, "wb") as f:
        data = analysis.fetch_trial_dataframes()
        pickle.dump(data, f)
        print("Result is saved at: <{}>".format(pkl_path))
    return analysis
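
A typical call of train, with placeholder values to show how the arguments interact: a scalar stop is converted to a timesteps_total criterion, and num_seeds expands into a tune grid search over seeds.

# Hypothetical call; trainer, env and budget are placeholders.
train(
    trainer="PPO",
    config={"env": "BipedalWalker-v2"},
    stop=1000000,          # scalar -> {"timesteps_total": 1000000}
    exp_name="demo",
    num_seeds=3,           # seeds 0, 100, 200 via tune.grid_search
)
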
Example #18
        #     "path": "data/yaml/ppo-300-agents.yaml",
        # },
        {
            "number": 5,
            "path": "data/yaml/es-30-agents-0818.yaml"
        }
    ]

    yaml_output_path = "delete_me_please.yaml"

    name_ckpt_mapping = read_batch_yaml(yaml_path_dict_list)
    yaml_output_path = save_yaml(name_ckpt_mapping, yaml_output_path)

    ret = get_fft_cluster_finder(
        yaml_path=yaml_output_path,
        num_rollouts=num_rollouts,
        num_workers=num_workers,
        num_gpus=3
    )
    print(ret)
    return ret


if __name__ == '__main__':
    import os

    initialize_ray(test_mode=True, num_gpus=3)

    # os.chdir("../../")
    ret = test_get_fft_cluster_finder()
Example #19
def init_ray():
    initialize_ray(num_gpus=4,
                   test_mode=False,
                   object_store_memory=40 * int(1e9))
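
Note that two byte-count conventions for "40 GB" appear in this listing; they differ by about seven percent, and both are valid ways to size Ray's object store in bytes:

decimal_40gb = 40 * int(1e9)     # 40,000,000,000 bytes (this example)
binary_40gb = 40 * 1024 ** 3     # 42,949,672,960 bytes (other examples)
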
Example #20
def copy_files(exp_base_dir, file_dict=None):
    file_dict = file_dict or DEFAULT_FILES_PATH
    output_file_dict = {}
    for output_name, file_path in file_dict.items():
        with open(file_path, 'r') as f:
            data = f.read()
        output_path = osp.join(exp_base_dir, output_name)
        with open(output_path, 'w') as f:
            f.write(data)
        output_file_dict[output_name] = output_path
    return output_file_dict


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml-path", type=str, required=True)
    parser.add_argument("--exp-dir", type=str, required=True)
    parser.add_argument("--test-mode", action="store_true")
    parser.add_argument("--num-gpus", "-g", type=int, default=4)
    args = parser.parse_args()

    initialize_ray(test_mode=args.test_mode, num_gpus=args.num_gpus)

    name_path_dict = generate_gif(args.yaml_path, args.exp_dir)
    print("Finish generate gif.")
    generate_json(args.yaml_path, name_path_dict, args.exp_dir)
    print("Finish generate json.")
    output_file_dict = copy_files(args.exp_dir)
    print("Finish generate html.")
    print("Generate files: ", output_file_dict)
Example #21
def get_fft_cluster_finder(
        yaml_path,
        normalize="std",
        try_standardize=False,
        num_agents=None,
        num_seeds=1,
        num_rollouts=100,
        num_workers=10,
        padding="fix",
        padding_length=500,
        padding_value=0,
        show=False,
        num_gpus=0
):
    assert yaml_path.endswith('yaml')
    initialize_ray(num_gpus=num_gpus)

    name_ckpt_mapping = get_name_ckpt_mapping(yaml_path, num_agents)
    print("Successfully loaded name_ckpt_mapping!")

    num_agents = num_agents or len(name_ckpt_mapping)
    # prefix: data/XXX_10agent_100rollout_1seed_28sm29sk
    # /XXX_10agent_100rollout_1seed_28sm29sk
    yaml_dir = osp.dirname(yaml_path)
    base = osp.basename(yaml_path)
    prefix = "".join(
        [
            base.split('.yaml')[0], "_{}agents_{}rollout_{}seed_{}".format(
                num_agents, num_rollouts, num_seeds, get_random_string()
            )
        ]
    )
    os.mkdir(osp.join(yaml_dir, prefix))
    prefix = osp.join(yaml_dir, prefix, prefix)

    data_frame_dict, repr_dict = get_fft_representation(
        name_ckpt_mapping,
        num_seeds,
        num_rollouts,
        normalize=normalize,
        num_workers=num_workers
    )
    print("Successfully get FFT representation!")

    cluster_df = parse_representation_dict(
        repr_dict, padding, padding_length, padding_value
    )
    print("Successfully get cluster dataframe!")

    # Store
    assert isinstance(cluster_df, pandas.DataFrame)
    pkl_path = prefix + '.pkl'
    cluster_df.to_pickle(pkl_path)
    print("Successfully store cluster_df! Save at: {}".format(pkl_path))

    # Cluster
    nostd_cluster_finder = ClusterFinder(cluster_df, standardize=False)
    nostd_fig_path = prefix + '_nostd.png'
    nostd_cluster_finder.display(save=nostd_fig_path, show=show)
    print(
        "Successfully finished non-standardized clustering! Saved at: {}".
        format(nostd_fig_path)
    )

    ret = {
        "cluster_finder": {
            'nostd_cluster_finder': nostd_cluster_finder
        },
        "prefix": prefix,
        "cluster_df": cluster_df,
        "data_frame_dict": data_frame_dict,
        "repr_dict": repr_dict
    }

    if try_standardize:
        std_cluster_finder = ClusterFinder(cluster_df, standardize=True)
        std_fig_path = prefix + "_std.png"
        std_cluster_finder.display(save=std_fig_path, show=show)
        print(
            "Successfully finished standardized clustering! Saved at: {}".
            format(std_fig_path)
        )
        ret['cluster_finder']["std_cluster_finder"] = std_cluster_finder

    return ret
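
The YAML files consumed here map agent names to checkpoint records. From the literal ckpt dict in Example #9 and the fields read in Example #22 (path, env_name, run_name), each entry of name_ckpt_mapping presumably looks like this (values are placeholders):

# Presumed shape after read_yaml / get_name_ckpt_mapping.
name_ckpt_mapping = {
    "agent_a": {
        "path": "~/ray_results/.../checkpoint_469/checkpoint-469",
        "env_name": "BipedalWalker-v2",
        "run_name": "PPO",
    },
}
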
Example #22
def get_fft_representation(
        name_ckpt_mapping,
        num_seeds,
        num_rollouts,
        padding="fix",
        padding_length=500,
        padding_value=0,
        stack=False,
        normalize="range",
        num_workers=10
):
    initialize_ray()

    data_frame_dict = {}
    representation_dict = {}

    num_agents = len(name_ckpt_mapping)

    num_iteration = int(ceil(num_agents / num_workers))

    agent_ckpt_dict_range = list(name_ckpt_mapping.items())
    agent_count = 1
    agent_count_get = 1

    workers = [FFTWorker.remote() for _ in range(num_workers)]
    now_t_get = now_t = start_t = time.time()

    for iteration in range(num_iteration):
        start = iteration * num_workers
        end = min((iteration + 1) * num_workers, num_agents)
        df_obj_ids = []
        for i, (name, ckpt_dict) in enumerate(
                agent_ckpt_dict_range[start:end]):
            ckpt = ckpt_dict["path"]
            env_name = ckpt_dict["env_name"]
            run_name = ckpt_dict["run_name"]
            env_maker = get_env_maker(env_name)
            workers[i].reset.remote(
                run_name=run_name,
                ckpt=ckpt,
                num_rollouts=num_rollouts,
                env_name=env_name,
                env_maker=env_maker,
                agent_name=name,
                padding=padding,
                padding_length=padding_length,
                padding_value=padding_value,
                worker_name="Worker{}".format(i)
            )

            df_obj_id = workers[i].fft.remote(
                normalize=normalize,
                _extra_name="[{}/{}] ".format(agent_count, num_agents)
            )

            print(
                "[{}/{}] (+{:.1f}s/{:.1f}s) Start collecting data from agent "
                "<{}>".format(
                    agent_count, num_agents,
                    time.time() - now_t,
                    time.time() - start_t, name
                )
            )
            agent_count += 1
            now_t = time.time()
            df_obj_ids.append(df_obj_id)

        for df_obj_id, (name, _) in zip(df_obj_ids,
                                        agent_ckpt_dict_range[start:end]):
            df, rep = copy.deepcopy(ray.get(df_obj_id))
            data_frame_dict[name] = df
            representation_dict[name] = rep
            print(
                "[{}/{}] (+{:.1f}s/{:.1f}s) Got data from agent <{}>".format(
                    agent_count_get, num_agents,
                    time.time() - now_t_get,
                    time.time() - start_t, name
                )
            )
            agent_count_get += 1
            now_t_get = time.time()
    return data_frame_dict, representation_dict
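
The submit-then-collect loop above is an instance of a generic bounded-parallelism pattern over Ray workers: schedule at most num_workers remote tasks, block on their results, then move to the next batch. The same skeleton, stripped of the FFT-specific details (make_task is a stand-in for any function returning a Ray object ref):

from math import ceil

import ray

def run_in_batches(items, make_task, num_workers):
    # At most num_workers tasks are in flight at any time.
    results = {}
    num_iteration = int(ceil(len(items) / num_workers))
    for it in range(num_iteration):
        batch = items[it * num_workers:(it + 1) * num_workers]
        obj_ids = [make_task(i, item) for i, item in enumerate(batch)]
        for item, obj_id in zip(batch, obj_ids):
            results[item] = ray.get(obj_id)
    return results
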
Example #23
    parser.add_argument("--mean", type=float, default=1.0)
    parser.add_argument("--mask-mode", type=str, default="multiply")
    args = parser.parse_args()

    yaml_path = args.yaml_path
    num_agents = args.num_agents
    num_rollouts = args.num_rollouts
    num_workers = args.num_workers
    num_children = args.num_children

    normal_std = args.std
    normal_mean = args.mean

    dir_name = args.output_path

    from toolbox.utils import initialize_ray

    initialize_ray(test_mode=True)

    symbolic_agent_rollout(
        yaml_path,
        num_agents,
        num_rollouts,
        num_workers,
        num_children,
        normal_std,
        normal_mean,
        dir_name,
        mask_mode=args.mask_mode
    )
Example #24
def get_ablation_result(ckpt,
                        run_name,
                        env_name,
                        env_maker,
                        num_rollouts,
                        layer_name,
                        num_units,
                        agent_name,
                        num_worker=10,
                        local_mode=False,
                        save=None,
                        _num_steps=None):
    initialize_ray(local_mode)
    workers = [AblationWorker.remote() for _ in range(num_worker)]
    now_t_get = now_t = start_t = time.time()
    agent_count = 1
    agent_count_get = 1
    num_iteration = int(ceil(num_units / num_worker))

    result_dict = {}
    result_obj_ids = []
    kl_obj_ids = []

    # Unit index None stands for the baseline test, i.e. no unit is ablated.
    baseline_worker = AblationWorker.remote()
    baseline_worker.reset.remote(run_name=run_name,
                                 ckpt=ckpt,
                                 env_name=env_name,
                                 env_maker=env_maker,
                                 agent_name=agent_name,
                                 worker_name="Baseline Worker")

    baseline_result = copy.deepcopy(
        ray.get(
            baseline_worker.ablate.remote(
                num_rollouts=num_rollouts,
                layer_name=layer_name,
                unit_index=None,  # None stands for no ablation
                return_trajectory=True,
                _num_steps=_num_steps,
                save=save,
            )))

    baseline_trajectory_list = copy.deepcopy(baseline_result.pop("trajectory"))
    baseline_result["kl_divergence"] = 0.0
    result_dict[_get_unit_name(layer_name, None)] = baseline_result

    unit_index_list = list(range(num_units))

    for iteration in range(num_iteration):
        start = iteration * num_worker
        end = min((iteration + 1) * num_worker, len(unit_index_list))

        for worker_index, unit_index in enumerate(unit_index_list[start:end]):
            workers[worker_index].reset.remote(
                run_name=run_name,
                ckpt=ckpt,
                env_name=env_name,
                env_maker=env_maker,
                agent_name=agent_name,
                worker_name="Worker{}".format(worker_index))

            obj_id = workers[worker_index].ablate.remote(
                num_rollouts=num_rollouts,
                layer_name=layer_name,
                unit_index=unit_index,
                save=save,
                _num_steps=_num_steps)
            result_obj_ids.append(obj_id)

            kl_obj_id = workers[worker_index].compute_kl_divergence.remote(
                baseline_trajectory_list)
            kl_obj_ids.append(kl_obj_id)

            print("{}/{} (Unit {}) (+{:.1f}s/{:.1f}s) Start collecting data.".
                  format(
                      agent_count,
                      len(unit_index_list),
                      unit_index,
                      time.time() - now_t,
                      time.time() - start_t,
                  ))
            agent_count += 1
            now_t = time.time()
        for obj_id, kl_obj_id in zip(result_obj_ids, kl_obj_ids):
            result = copy.deepcopy(ray.get(obj_id))
            layer_name = result["layer_name"]
            unit_index = result["ablated_unit_index"]
            unit_name = _get_unit_name(layer_name, unit_index)
            result["kl_divergence"] = ray.get(kl_obj_id)
            result_dict[unit_name] = result
            print("{}/{} (Unit {}) (+{:.1f}s/{:.1f}s) Start collecting data.".
                  format(agent_count_get, len(unit_index_list), unit_index,
                         time.time() - now_t_get,
                         time.time() - start_t))
            agent_count_get += 1
            now_t_get = time.time()
        result_obj_ids.clear()
        kl_obj_ids.clear()

    ret = _parse_result_dict(result_dict)
    return ret