Example #1
def test_run_hinge_success():
    env_suite = "kitchen"
    env_name = "hinge_cabinet"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(
        env_suite,
        env_name,
        env_kwargs,
    )
    env.reset()
    ctr = 0
    max_path_length = 5
    for _ in range(max_path_length):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            env.reset()
            a[env.get_idx_from_primitive_name("lift")] = 1
            a[env.num_primitives +
              env.primitive_name_to_action_idx["lift"]] = 1
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("angled_x_y_grasp")] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx["angled_x_y_grasp"]
                       )] = np.array([-np.pi / 6, -0.3, 1.4, 0])
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("move_delta_ee_pose")] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx["move_delta_ee_pose"]
                       )] = np.array([0.5, -1, 0])
        if ctr % max_path_length == 3:
            a[env.get_idx_from_primitive_name("rotate_about_x_axis")] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx["rotate_about_x_axis"]
                       )] = np.array([
                           1,
                       ])
        if ctr % max_path_length == 4:
            a[env.get_idx_from_primitive_name("rotate_about_x_axis")] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx["rotate_about_x_axis"]
                       )] = np.array([
                           0,
                       ])
        o, r, d, i = env.step(a / 1.4, )
        ctr += 1
    assert r == 1.0
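Note (editor's sketch, not part of the test suite): every step in the tests above builds the same kind of flat action vector, where the first num_primitives entries one-hot select a primitive and the remaining entries hold each primitive's arguments at the offsets given by primitive_name_to_action_idx. Assuming only the env attributes already used above, the pattern could be factored into a helper like this:

import numpy as np


def make_primitive_action(env, name, args=1.0):
    # Hypothetical helper, not part of the library: one-hot select the
    # primitive, then write its arguments into the matching argument slots.
    a = np.zeros(env.action_space.low.size)
    a[env.get_idx_from_primitive_name(name)] = 1
    idx = np.array(env.primitive_name_to_action_idx[name])
    a[env.num_primitives + idx] = np.array(args)
    return a

With such a helper, the first step of Example #1 would read a = make_primitive_action(env, "lift", 1), still divided by the action scale before env.step.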
Example #2
def test_light_switch_success():
    env_suite = "kitchen"
    env_name = "light_switch"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(
        env_suite,
        env_name,
        env_kwargs,
    )
    env.reset()

    max_path_length = 5
    ctr = 0
    for i in range(max_path_length):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            a[env.get_idx_from_primitive_name("close_gripper")] = 1
            a[
                env.num_primitives + np.array(env.primitive_name_to_action_idx["lift"])
            ] = 1
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("lift")] = 1
            a[
                env.num_primitives + np.array(env.primitive_name_to_action_idx["lift"])
            ] = 0.6
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("move_right")] = 1
            a[
                env.num_primitives + env.primitive_name_to_action_idx["move_right"]
            ] = 0.45
        if ctr % max_path_length == 3:
            a[env.get_idx_from_primitive_name("move_forward")] = 1
            a[
                env.num_primitives + env.primitive_name_to_action_idx["move_forward"]
            ] = 1.25
        if ctr % max_path_length == 4:
            a[env.get_idx_from_primitive_name("move_left")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["move_left"]] = 0.45
        o, r, d, _ = env.step(
            a / 1.4,
        )
        ctr += 1
    assert r == 1.0
Example #3
def reset(self):
    if hasattr(self, "env"):
        del self.env
        gc.collect()
    self.idx = self.num_resets % self.num_multitask_envs
    env_name = self.env_names[self.idx]
    self.env = primitives_make_env.make_env(self.env_suite, env_name,
                                            self.env_kwargs)
    o = self.env.reset()
    self.num_resets += 1
    o = np.concatenate((o, self.get_one_hot(self.idx)))
    return o
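get_one_hot is not shown in this snippet; a minimal sketch of what it presumably returns, given how reset() concatenates it onto the observation (an assumption based only on the names used above):

def get_one_hot(self, idx):
    # Assumed helper: a one-hot indicator over the multitask environments,
    # appended to the observation so the policy knows the active task.
    one_hot = np.zeros(self.num_multitask_envs)
    one_hot[idx] = 1.0
    return one_hot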
Example #4
def test_run_kettle_success():
    env_suite = "kitchen"
    env_name = "kettle"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(
        env_suite,
        env_name,
        env_kwargs,
    )
    env.reset()
    ctr = 0
    max_path_length = 5
    for i in range(max_path_length):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            env.reset()
            a[env.get_idx_from_primitive_name("drop")] = 1
            a[env.num_primitives +
              env.primitive_name_to_action_idx["drop"]] = 0.5
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("angled_x_y_grasp")] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx["angled_x_y_grasp"]
                       )] = np.array([0, 0.15, 0.7, 1])
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("move_delta_ee_pose")] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx["move_delta_ee_pose"]
                       )] = np.array([0.25, 1.0, 0.25])
        if ctr % max_path_length == 3:
            a[env.get_idx_from_primitive_name("drop")] = 1
            a[env.num_primitives +
              env.primitive_name_to_action_idx["drop"]] = 0.25
        if ctr % max_path_length == 4:
            a[env.get_idx_from_primitive_name("open_gripper")] = 1
            a[env.num_primitives +
              env.primitive_name_to_action_idx["open_gripper"]] = 1
        o, r, d, _ = env.step(a / 1.4)
        ctr += 1
    assert r == 1.0
Example #5
def test_top_burner_success():
    env_suite = "kitchen"
    env_name = "top_left_burner"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(
        env_suite,
        env_name,
        env_kwargs,
    )
    env.reset()
    ctr = 0
    max_path_length = 3
    for i in range(max_path_length):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            env.reset()
            a[env.get_idx_from_primitive_name("lift")] = 1
            a[env.num_primitives +
              env.primitive_name_to_action_idx["lift"]] = 0.6
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("angled_x_y_grasp")] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx["angled_x_y_grasp"]
                       )] = np.array([0, 0.5, 1, 1])
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("rotate_about_y_axis")] = 1
            a[env.num_primitives +
              env.primitive_name_to_action_idx["rotate_about_y_axis"]] = (
                  -np.pi / 4)

        o, r, d, _ = env.step(a / 1.4, )

        ctr += 1
    assert r == 1.0
Example #6
def test_run_microwave_success():
    env_suite = "kitchen"
    env_name = "microwave"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(
        env_suite,
        env_name,
        env_kwargs,
    )
    env.reset()
    ctr = 0
    max_path_length = 3
    for i in range(3):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            env.reset()
            a[env.get_idx_from_primitive_name("drop")] = 1
            a[env.num_primitives +
              env.primitive_name_to_action_idx["drop"]] = 0.55
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("angled_x_y_grasp")] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx["angled_x_y_grasp"]
                       )] = np.array([-np.pi / 6, -0.3, 0.95, 1])
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("move_backward")] = 1
            a[env.num_primitives +
              env.primitive_name_to_action_idx["move_backward"]] = 0.6

        o, r, d, _ = env.step(a / 1.4, )
        ctr += 1
    assert r == 1.0
Example #7
def test_dummy_vec_env_save_load():
    env_kwargs = dict(
        use_image_obs=True,
        imwidth=64,
        imheight=64,
        reward_type="sparse",
        usage_kwargs=dict(
            max_path_length=5,
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
        ),
        action_space_kwargs=dict(
            control_mode="primitives",
            action_scale=1,
            camera_settings={
                "distance": 0.38227044687537043,
                "lookat": [0.21052547, 0.32329237, 0.587819],
                "azimuth": 141.328125,
                "elevation": -53.203125160653144,
            },
        ),
    )
    env_suite = "metaworld"
    env_name = "disassemble-v2"
    make_env_lambda = lambda: make_env(env_suite, env_name, env_kwargs)

    n_envs = 2
    envs = [make_env_lambda() for _ in range(n_envs)]
    env = DummyVecEnv(
        envs,
    )
    with tempfile.TemporaryDirectory() as tmpdirname:
        env.save(tmpdirname, "env.pkl")
        env = DummyVecEnv(
            envs[0:1],
        )
        new_env = env.load(tmpdirname, "env.pkl")
    assert new_env.n_envs == n_envs
Example #8
def test_path_collector_save_load():
    env_kwargs = dict(
        use_image_obs=True,
        imwidth=64,
        imheight=64,
        reward_type="sparse",
        usage_kwargs=dict(
            max_path_length=5,
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
        ),
        action_space_kwargs=dict(
            control_mode="primitives",
            action_scale=1,
            camera_settings={
                "distance": 0.38227044687537043,
                "lookat": [0.21052547, 0.32329237, 0.587819],
                "azimuth": 141.328125,
                "elevation": -53.203125160653144,
            },
        ),
    )
    actor_kwargs = dict(
        discrete_continuous_dist=True,
        init_std=0.0,
        num_layers=4,
        min_std=0.1,
        dist="tanh_normal_dreamer_v1",
    )
    model_kwargs = dict(
        model_hidden_size=400,
        stochastic_state_size=50,
        deterministic_state_size=200,
        rssm_hidden_size=200,
        reward_num_layers=2,
        pred_discount_num_layers=3,
        gru_layer_norm=True,
        std_act="sigmoid2",
        use_prior_instead_of_posterior=False,
    )
    env_suite = "metaworld"
    env_name = "disassemble-v2"
    eval_envs = [make_env(env_suite, env_name, env_kwargs) for _ in range(1)]
    eval_env = DummyVecEnv(eval_envs, )

    discrete_continuous_dist = True
    continuous_action_dim = eval_envs[0].max_arg_len
    discrete_action_dim = eval_envs[0].num_primitives
    if not discrete_continuous_dist:
        continuous_action_dim = continuous_action_dim + discrete_action_dim
        discrete_action_dim = 0
    action_dim = continuous_action_dim + discrete_action_dim
    obs_dim = eval_env.observation_space.low.size

    world_model = WorldModel(
        action_dim,
        image_shape=eval_envs[0].image_shape,
        **model_kwargs,
    )
    actor = ActorModel(
        model_kwargs["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **actor_kwargs,
    )

    eval_policy = DreamerPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )

    eval_path_collector = VecMdpPathCollector(
        eval_env,
        eval_policy,
        save_env_in_snapshot=False,
    )

    with tempfile.TemporaryDirectory() as tmpdirname:
        eval_path_collector.save(tmpdirname, "path_collector.pkl")
        eval_path_collector = VecMdpPathCollector(
            eval_env,
            eval_policy,
            save_env_in_snapshot=False,
        )
        new_path_collector = eval_path_collector.load(tmpdirname,
                                                      "path_collector.pkl")
Example #9
def experiment(variant):
    import os
    import os.path as osp

    os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"

    import torch
    import torch.nn as nn

    import rlkit.envs.primitives_make_env as primitives_make_env
    import rlkit.torch.pytorch_util as ptu
    from rlkit.envs.wrappers.mujoco_vec_wrappers import (
        DummyVecEnv,
        StableBaselinesVecEnv,
    )
    from rlkit.torch.model_based.dreamer.actor_models import ActorModel
    from rlkit.torch.model_based.dreamer.dreamer_policy import (
        ActionSpaceSamplePolicy,
        DreamerPolicy,
    )
    from rlkit.torch.model_based.dreamer.dreamer_v2 import DreamerV2Trainer
    from rlkit.torch.model_based.dreamer.episode_replay_buffer import (
        EpisodeReplayBuffer,
        EpisodeReplayBufferLowLevelRAPS,
    )
    from rlkit.torch.model_based.dreamer.mlp import Mlp
    from rlkit.torch.model_based.dreamer.path_collector import VecMdpPathCollector
    from rlkit.torch.model_based.dreamer.visualization import post_epoch_visualize_func
    from rlkit.torch.model_based.dreamer.world_models import WorldModel
    from rlkit.torch.model_based.rl_algorithm import TorchBatchRLAlgorithm

    env_suite = variant.get("env_suite", "kitchen")
    env_name = variant["env_name"]
    env_kwargs = variant["env_kwargs"]
    use_raw_actions = variant["use_raw_actions"]
    num_expl_envs = variant["num_expl_envs"]
    if num_expl_envs > 1:
        env_fns = [
            lambda: primitives_make_env.make_env(
                env_suite, env_name, env_kwargs) for _ in range(num_expl_envs)
        ]
        expl_env = StableBaselinesVecEnv(
            env_fns=env_fns,
            start_method="fork",
            reload_state_args=(
                num_expl_envs,
                primitives_make_env.make_env,
                (env_suite, env_name, env_kwargs),
            ),
        )
    else:
        expl_envs = [
            primitives_make_env.make_env(env_suite, env_name, env_kwargs)
        ]
        expl_env = DummyVecEnv(expl_envs,
                               pass_render_kwargs=variant.get(
                                   "pass_render_kwargs", False))
    eval_envs = [
        primitives_make_env.make_env(env_suite, env_name, env_kwargs)
        for _ in range(1)
    ]
    eval_env = DummyVecEnv(eval_envs,
                           pass_render_kwargs=variant.get(
                               "pass_render_kwargs", False))
    if use_raw_actions:
        discrete_continuous_dist = False
        continuous_action_dim = eval_env.action_space.low.size
        discrete_action_dim = 0
        use_batch_length = True
        action_dim = continuous_action_dim
    else:
        discrete_continuous_dist = variant["actor_kwargs"][
            "discrete_continuous_dist"]
        continuous_action_dim = eval_envs[0].max_arg_len
        discrete_action_dim = eval_envs[0].num_primitives
        if not discrete_continuous_dist:
            continuous_action_dim = continuous_action_dim + discrete_action_dim
            discrete_action_dim = 0
        action_dim = continuous_action_dim + discrete_action_dim
        use_batch_length = False
    obs_dim = expl_env.observation_space.low.size

    world_model = WorldModel(
        action_dim,
        image_shape=eval_envs[0].image_shape,
        **variant["model_kwargs"],
    )
    actor = ActorModel(
        variant["model_kwargs"]["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **variant["actor_kwargs"],
    )
    vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]] *
        variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )
    target_vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]] *
        variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )
    if variant.get("models_path", None) is not None:
        filename = variant["models_path"]
        actor.load_state_dict(torch.load(osp.join(filename, "actor.ptc")))
        vf.load_state_dict(torch.load(osp.join(filename, "vf.ptc")))
        target_vf.load_state_dict(
            torch.load(osp.join(filename, "target_vf.ptc")))
        world_model.load_state_dict(
            torch.load(osp.join(filename, "world_model.ptc")))
        print("LOADED MODELS")

    expl_policy = DreamerPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        exploration=True,
        expl_amount=variant.get("expl_amount", 0.3),
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )
    eval_policy = DreamerPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )

    rand_policy = ActionSpaceSamplePolicy(expl_env)

    expl_path_collector = VecMdpPathCollector(
        expl_env,
        expl_policy,
        save_env_in_snapshot=False,
    )

    eval_path_collector = VecMdpPathCollector(
        eval_env,
        eval_policy,
        save_env_in_snapshot=False,
    )

    variant["replay_buffer_kwargs"]["use_batch_length"] = use_batch_length
    replay_buffer = EpisodeReplayBuffer(
        num_expl_envs,
        obs_dim,
        action_dim,
        **variant["replay_buffer_kwargs"],
    )
    eval_filename = variant.get("eval_buffer_path", None)
    if eval_filename is not None:
        eval_buffer = EpisodeReplayBufferLowLevelRAPS(
            1000,
            expl_env,
            variant["algorithm_kwargs"]["max_path_length"],
            10,
            obs_dim,
            action_dim,
            9,
            replace=False,
        )
        eval_buffer.load_buffer(eval_filename, eval_env.envs[0].num_primitives)
    else:
        eval_buffer = None
    trainer = DreamerV2Trainer(
        actor,
        vf,
        target_vf,
        world_model,
        eval_envs[0].image_shape,
        **variant["trainer_kwargs"],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        pretrain_policy=rand_policy,
        eval_buffer=eval_buffer,
        **variant["algorithm_kwargs"],
    )
    algorithm.low_level_primitives = False
    if variant.get("generate_video", False):
        post_epoch_visualize_func(algorithm, 0)
    else:
        if variant.get("save_video", False):
            algorithm.post_epoch_funcs.append(post_epoch_visualize_func)
        print("TRAINING")
        algorithm.to(ptu.device)
        algorithm.train()
        if variant.get("save_video", False):
            post_epoch_visualize_func(algorithm, -1)
Example #10
def run_trained_policy(path):
    ptu.set_gpu_mode(True)
    variant = json.load(open(osp.join(path, "variant.json"), "r"))
    set_seed(variant["seed"])
    variant = preprocess_variant_llraps(variant)
    env_suite = variant.get("env_suite", "kitchen")
    env_kwargs = variant["env_kwargs"]
    num_low_level_actions_per_primitive = variant[
        "num_low_level_actions_per_primitive"]
    low_level_action_dim = variant["low_level_action_dim"]

    env_name = variant["env_name"]
    make_env_lambda = lambda: make_env(env_suite, env_name, env_kwargs)

    eval_envs = [make_env_lambda() for _ in range(1)]
    eval_env = DummyVecEnv(eval_envs,
                           pass_render_kwargs=variant.get(
                               "pass_render_kwargs", False))

    discrete_continuous_dist = variant["actor_kwargs"][
        "discrete_continuous_dist"]
    num_primitives = eval_envs[0].num_primitives
    continuous_action_dim = eval_envs[0].max_arg_len
    discrete_action_dim = num_primitives
    if not discrete_continuous_dist:
        continuous_action_dim = continuous_action_dim + discrete_action_dim
        discrete_action_dim = 0
    action_dim = continuous_action_dim + discrete_action_dim
    obs_dim = eval_env.observation_space.low.size

    primitive_model = Mlp(
        output_size=variant["low_level_action_dim"],
        input_size=variant["model_kwargs"]["stochastic_state_size"] +
        variant["model_kwargs"]["deterministic_state_size"] +
        eval_env.envs[0].action_space.low.shape[0] + 1,
        hidden_activation=nn.ReLU,
        num_embeddings=eval_envs[0].num_primitives,
        embedding_dim=eval_envs[0].num_primitives,
        embedding_slice=eval_envs[0].num_primitives,
        **variant["primitive_model_kwargs"],
    )

    world_model = LowlevelRAPSWorldModel(
        low_level_action_dim,
        image_shape=eval_envs[0].image_shape,
        primitive_model=primitive_model,
        **variant["model_kwargs"],
    )
    actor = ActorModel(
        variant["model_kwargs"]["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **variant["actor_kwargs"],
    )
    actor.load_state_dict(torch.load(osp.join(path, "actor.ptc")))
    world_model.load_state_dict(torch.load(osp.join(path, "world_model.ptc")))

    actor.to(ptu.device)
    world_model.to(ptu.device)

    eval_policy = DreamerLowLevelRAPSPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
        low_level_action_dim=low_level_action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            for step in range(
                    0, variant["algorithm_kwargs"]["max_path_length"] + 1):
                if step == 0:
                    observation = eval_env.envs[0].reset()
                    eval_policy.reset(observation.reshape(1, -1))
                    policy_o = (None, observation.reshape(1, -1))
                    reward = 0
                else:
                    high_level_action, _ = eval_policy.get_action(policy_o, )
                    observation, reward, done, info = eval_env.envs[0].step(
                        high_level_action[0], )
                    low_level_obs = np.expand_dims(
                        np.array(info["low_level_obs"]), 0)
                    low_level_action = np.expand_dims(
                        np.array(info["low_level_action"]), 0)
                    policy_o = (low_level_action, low_level_obs)
    return reward
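A hedged usage sketch for run_trained_policy (the path below is a placeholder): it expects a run directory containing variant.json plus the actor.ptc and world_model.ptc checkpoints loaded above, and returns the reward from the final environment step.

if __name__ == "__main__":
    # Placeholder path; point it at a directory saved by a training run.
    final_reward = run_trained_policy("/path/to/trained/run")
    print("final step reward:", final_reward)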
Example #11
     image_kwargs=dict(imwidth=64, imheight=64),
     collect_primitives_info=True,
     include_phase_variable=True,
     render_intermediate_obs_to_info=not args.collect_data_fn
     == "collect_primitive_cloning_data",
     num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
 )
 datafile = "wm_H_{}_T_{}_E_{}_P_{}_raps_ll_hl_even_rt_{}".format(
     args.max_path_length,
     num_trajs,
     args.num_envs,
     num_low_level_actions_per_primitive,
     env_name,
 )
 env_fns = [
     lambda: make_env(env_suite, env_name, env_kwargs)
     for _ in range(args.num_envs)
 ]
 env = StableBaselinesVecEnv(env_fns=env_fns, start_method="fork")
 if args.collect_data_fn == "collect_world_model_data":
     data = collect_world_model_data(
         env,
         num_trajs * args.num_envs,
         args.num_envs,
         args.max_path_length,
     )
     save_data(data, datafile)
 elif (
     args.collect_data_fn
     == "collect_world_model_data_low_level_primitives"
 ):
Example #12
def experiment(variant):
    import os

    import rlkit.envs.primitives_make_env as primitives_make_env

    os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"
    import torch

    import rlkit.torch.pytorch_util as ptu
    from rlkit.envs.wrappers.mujoco_vec_wrappers import (
        DummyVecEnv,
        StableBaselinesVecEnv,
    )
    from rlkit.torch.model_based.dreamer.actor_models import ActorModel
    from rlkit.torch.model_based.dreamer.dreamer_policy import (
        ActionSpaceSamplePolicy,
        DreamerPolicy,
    )
    from rlkit.torch.model_based.dreamer.episode_replay_buffer import (
        EpisodeReplayBuffer, )
    from rlkit.torch.model_based.dreamer.mlp import Mlp
    from rlkit.torch.model_based.dreamer.path_collector import VecMdpPathCollector
    from rlkit.torch.model_based.dreamer.visualization import video_post_epoch_func
    from rlkit.torch.model_based.dreamer.world_models import WorldModel
    from rlkit.torch.model_based.plan2explore.latent_space_models import (
        OneStepEnsembleModel, )
    from rlkit.torch.model_based.plan2explore.plan2explore import Plan2ExploreTrainer
    from rlkit.torch.model_based.rl_algorithm import TorchBatchRLAlgorithm

    env_suite = variant.get("env_suite", "kitchen")
    env_name = variant["env_name"]
    env_kwargs = variant["env_kwargs"]
    use_raw_actions = variant["use_raw_actions"]
    num_expl_envs = variant["num_expl_envs"]
    actor_model_class_name = variant.get("actor_model_class", "actor_model")

    if num_expl_envs > 1:
        env_fns = [
            lambda: primitives_make_env.make_env(
                env_suite, env_name, env_kwargs) for _ in range(num_expl_envs)
        ]
        expl_env = StableBaselinesVecEnv(env_fns=env_fns, start_method="fork")
    else:
        expl_envs = [
            primitives_make_env.make_env(env_suite, env_name, env_kwargs)
        ]
        expl_env = DummyVecEnv(expl_envs,
                               pass_render_kwargs=variant.get(
                                   "pass_render_kwargs", False))
    eval_envs = [
        primitives_make_env.make_env(env_suite, env_name, env_kwargs)
        for _ in range(1)
    ]
    eval_env = DummyVecEnv(eval_envs,
                           pass_render_kwargs=variant.get(
                               "pass_render_kwargs", False))
    if use_raw_actions:
        discrete_continuous_dist = False
        continuous_action_dim = eval_env.action_space.low.size
        discrete_action_dim = 0
        use_batch_length = True
        action_dim = continuous_action_dim
    else:
        discrete_continuous_dist = variant["actor_kwargs"][
            "discrete_continuous_dist"]
        continuous_action_dim = eval_envs[0].max_arg_len
        discrete_action_dim = eval_envs[0].num_primitives
        if not discrete_continuous_dist:
            continuous_action_dim = continuous_action_dim + discrete_action_dim
            discrete_action_dim = 0
        action_dim = continuous_action_dim + discrete_action_dim
        use_batch_length = False
    world_model_class = WorldModel
    obs_dim = expl_env.observation_space.low.size
    actor_model_class = ActorModel
    if variant.get("load_from_path", False):
        data = torch.load(variant["models_path"])
        actor = data["trainer/actor"]
        vf = data["trainer/vf"]
        target_vf = data["trainer/target_vf"]
        world_model = data["trainer/world_model"]
    else:
        world_model = world_model_class(
            action_dim,
            image_shape=eval_envs[0].image_shape,
            **variant["model_kwargs"],
            env=eval_envs[0].env,
        )
        actor = actor_model_class(
            variant["model_kwargs"]["model_hidden_size"],
            world_model.feature_size,
            hidden_activation=torch.nn.functional.elu,
            discrete_action_dim=discrete_action_dim,
            continuous_action_dim=continuous_action_dim,
            env=eval_envs[0].env,
            **variant["actor_kwargs"],
        )
        vf = Mlp(
            hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]] *
            variant["vf_kwargs"]["num_layers"],
            output_size=1,
            input_size=world_model.feature_size,
            hidden_activation=torch.nn.functional.elu,
        )
        target_vf = Mlp(
            hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]] *
            variant["vf_kwargs"]["num_layers"],
            output_size=1,
            input_size=world_model.feature_size,
            hidden_activation=torch.nn.functional.elu,
        )

    one_step_ensemble = OneStepEnsembleModel(
        action_dim=action_dim,
        embedding_size=variant["model_kwargs"]["embedding_size"],
        deterministic_state_size=variant["model_kwargs"]
        ["deterministic_state_size"],
        stochastic_state_size=variant["model_kwargs"]["stochastic_state_size"],
        **variant["one_step_ensemble_kwargs"],
    )

    exploration_actor = actor_model_class(
        variant["model_kwargs"]["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=torch.nn.functional.elu,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        env=eval_envs[0],
        **variant["actor_kwargs"],
    )
    exploration_vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]] *
        variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=torch.nn.functional.elu,
    )
    exploration_target_vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]] *
        variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=torch.nn.functional.elu,
    )

    if variant.get("expl_with_exploration_actor", True):
        expl_actor = exploration_actor
    else:
        expl_actor = actor
    expl_policy = DreamerPolicy(
        world_model,
        expl_actor,
        obs_dim,
        action_dim,
        exploration=True,
        expl_amount=variant.get("expl_amount", 0.3),
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=variant["actor_kwargs"]
        ["discrete_continuous_dist"],
    )
    if variant.get("eval_with_exploration_actor", False):
        eval_actor = exploration_actor
    else:
        eval_actor = actor
    eval_policy = DreamerPolicy(
        world_model,
        eval_actor,
        obs_dim,
        action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=variant["actor_kwargs"]
        ["discrete_continuous_dist"],
    )

    rand_policy = ActionSpaceSamplePolicy(expl_env)

    expl_path_collector = VecMdpPathCollector(
        expl_env,
        expl_policy,
        save_env_in_snapshot=False,
    )

    eval_path_collector = VecMdpPathCollector(
        eval_env,
        eval_policy,
        save_env_in_snapshot=False,
    )

    replay_buffer = EpisodeReplayBuffer(
        variant["replay_buffer_size"],
        expl_env,
        variant["algorithm_kwargs"]["max_path_length"] + 1,
        obs_dim,
        action_dim,
        replace=False,
        use_batch_length=use_batch_length,
    )
    trainer = Plan2ExploreTrainer(
        eval_env,
        actor,
        vf,
        target_vf,
        world_model,
        eval_envs[0].image_shape,
        exploration_actor,
        exploration_vf,
        exploration_target_vf,
        one_step_ensemble,
        **variant["trainer_kwargs"],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        pretrain_policy=rand_policy,
        **variant["algorithm_kwargs"],
    )

    algorithm.post_epoch_funcs.append(video_post_epoch_func)
    algorithm.to(ptu.device)
    algorithm.train()
    video_post_epoch_func(algorithm, -1)
Example #13
def test_run_assembly_success():
    env_suite = "metaworld"
    env_name = "assembly-v2"
    env_kwargs = dict(
        use_image_obs=True,
        imwidth=64,
        imheight=64,
        reward_type="sparse",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(
            control_mode="primitives",
            action_scale=1,
            camera_settings={
                "distance": 0.38227044687537043,
                "lookat": [0.21052547, 0.32329237, 0.587819],
                "azimuth": 141.328125,
                "elevation": -53.203125160653144,
            },
        ),
    )
    render_mode = "rgb_array"
    render_im_shape = (64, 64)
    render_every_step = True
    env = make_env(
        env_suite,
        env_name,
        env_kwargs,
    )
    o = env.reset()
    for i in range(5):
        a = env.action_space.sample()
        a = np.zeros_like(a)
        if i % 5 == 0:
            primitive = "top_x_y_grasp"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx[primitive])] = [
                  0.25, 0.0, -0.6, 1
              ]
        elif i % 5 == 1:
            primitive = "lift"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx[primitive])] = 0.4
        elif i % 5 == 2:
            primitive = "move_forward"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx[primitive])] = 0.45
        elif i % 5 == 3:
            primitive = "move_right"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx[primitive])] = 0.05
        elif i % 5 == 4:
            primitive = "open_gripper"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives +
              np.array(env.primitive_name_to_action_idx[primitive])] = 1
        o, r, d, info = env.step(
            a,
            render_every_step=render_every_step,
            render_mode=render_mode,
            render_im_shape=render_im_shape,
        )
    assert r == 1.0
Example #14
def experiment(variant):
    import os
    from collections import OrderedDict

    import numpy as np
    import torch
    from torch import nn, optim
    from tqdm import tqdm

    import rlkit.torch.pytorch_util as ptu
    from rlkit.core import logger
    from rlkit.envs.primitives_make_env import make_env
    from rlkit.torch.model_based.dreamer.mlp import Mlp, MlpResidual
    from rlkit.torch.model_based.dreamer.train_world_model import (
        compute_world_model_loss,
        get_dataloader,
        get_dataloader_rt,
        get_dataloader_separately,
        update_network,
        visualize_rollout,
        world_model_loss_rt,
    )
    from rlkit.torch.model_based.dreamer.world_models import (
        LowlevelRAPSWorldModel,
        WorldModel,
    )

    env_suite, env_name, env_kwargs = (
        variant["env_suite"],
        variant["env_name"],
        variant["env_kwargs"],
    )
    max_path_length = variant["env_kwargs"]["max_path_length"]
    low_level_primitives = variant["low_level_primitives"]
    num_low_level_actions_per_primitive = variant[
        "num_low_level_actions_per_primitive"]
    low_level_action_dim = variant["low_level_action_dim"]
    dataloader_kwargs = variant["dataloader_kwargs"]
    env = make_env(env_suite, env_name, env_kwargs)
    world_model_kwargs = variant["model_kwargs"]
    optimizer_kwargs = variant["optimizer_kwargs"]
    gradient_clip = variant["gradient_clip"]
    if low_level_primitives:
        world_model_kwargs["action_dim"] = low_level_action_dim
    else:
        world_model_kwargs["action_dim"] = env.action_space.low.shape[0]
    image_shape = env.image_shape
    world_model_kwargs["image_shape"] = image_shape
    scaler = torch.cuda.amp.GradScaler()
    world_model_loss_kwargs = variant["world_model_loss_kwargs"]
    clone_primitives = variant["clone_primitives"]
    clone_primitives_separately = variant["clone_primitives_separately"]
    clone_primitives_and_train_world_model = variant.get(
        "clone_primitives_and_train_world_model", False)
    batch_len = variant.get("batch_len", 100)
    num_epochs = variant["num_epochs"]
    loss_to_use = variant.get("loss_to_use", "both")

    logdir = logger.get_snapshot_dir()

    if clone_primitives_separately:
        (
            train_dataloaders,
            test_dataloaders,
            train_datasets,
            test_datasets,
        ) = get_dataloader_separately(
            variant["datafile"],
            num_low_level_actions_per_primitive=
            num_low_level_actions_per_primitive,
            num_primitives=env.num_primitives,
            env=env,
            **dataloader_kwargs,
        )
    elif clone_primitives_and_train_world_model:
        print("LOADING DATA")
        (
            train_dataloader,
            test_dataloader,
            train_dataset,
            test_dataset,
        ) = get_dataloader_rt(
            variant["datafile"],
            max_path_length=max_path_length *
            num_low_level_actions_per_primitive + 1,
            **dataloader_kwargs,
        )
    elif low_level_primitives or clone_primitives:
        print("LOADING DATA")
        (
            train_dataloader,
            test_dataloader,
            train_dataset,
            test_dataset,
        ) = get_dataloader(
            variant["datafile"],
            max_path_length=max_path_length *
            num_low_level_actions_per_primitive + 1,
            **dataloader_kwargs,
        )
    else:
        train_dataloader, test_dataloader, train_dataset, test_dataset = get_dataloader(
            variant["datafile"],
            max_path_length=max_path_length + 1,
            **dataloader_kwargs,
        )

    if clone_primitives_and_train_world_model:
        if variant["mlp_act"] == "elu":
            mlp_act = nn.functional.elu
        elif variant["mlp_act"] == "relu":
            mlp_act = nn.functional.relu
        if variant["mlp_res"]:
            mlp_class = MlpResidual
        else:
            mlp_class = Mlp
        criterion = nn.MSELoss()
        primitive_model = mlp_class(
            hidden_sizes=variant["mlp_hidden_sizes"],
            output_size=low_level_action_dim,
            input_size=250 + env.action_space.low.shape[0] + 1,
            hidden_activation=mlp_act,
        ).to(ptu.device)
        world_model_class = LowlevelRAPSWorldModel
        world_model = world_model_class(
            primitive_model=primitive_model,
            **world_model_kwargs,
        ).to(ptu.device)
        optimizer = optim.Adam(
            world_model.parameters(),
            **optimizer_kwargs,
        )
        best_test_loss = np.inf
        for i in tqdm(range(num_epochs)):
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            total_primitive_loss = 0
            total_world_model_loss = 0
            total_div_loss = 0
            total_image_pred_loss = 0
            total_transition_loss = 0
            total_entropy_loss = 0
            total_pred_discount_loss = 0
            total_reward_pred_loss = 0
            total_train_steps = 0
            for data in train_dataloader:
                with torch.cuda.amp.autocast():
                    (
                        high_level_actions,
                        obs,
                        rewards,
                        terminals,
                    ), low_level_actions = data
                    obs = obs.to(ptu.device).float()
                    low_level_actions = low_level_actions.to(
                        ptu.device).float()
                    high_level_actions = high_level_actions.to(
                        ptu.device).float()
                    rewards = rewards.to(ptu.device).float()
                    terminals = terminals.to(ptu.device).float()
                    assert all(terminals[:, -1] == 1)
                    rt_idxs = np.arange(
                        num_low_level_actions_per_primitive,
                        obs.shape[1],
                        num_low_level_actions_per_primitive,
                    )
                    rt_idxs = np.concatenate(
                        [[0], rt_idxs]
                    )  # reset obs, effect of first primitive, second primitive, so on

                    batch_start = np.random.randint(0,
                                                    obs.shape[1] - batch_len,
                                                    size=(obs.shape[0]))
                    batch_indices = np.linspace(
                        batch_start,
                        batch_start + batch_len,
                        batch_len,
                        endpoint=False,
                    ).astype(int)
                    (
                        post,
                        prior,
                        post_dist,
                        prior_dist,
                        image_dist,
                        reward_dist,
                        pred_discount_dist,
                        _,
                        action_preds,
                    ) = world_model(
                        obs,
                        (high_level_actions, low_level_actions),
                        use_network_action=False,
                        batch_indices=batch_indices,
                        rt_idxs=rt_idxs,
                    )
                    obs = world_model.flatten_obs(
                        obs[np.arange(batch_indices.shape[1]),
                            batch_indices].permute(1, 0, 2),
                        (int(np.prod(image_shape)), ),
                    )
                    rewards = rewards.reshape(-1, rewards.shape[-1])
                    terminals = terminals.reshape(-1, terminals.shape[-1])
                    (
                        world_model_loss,
                        div,
                        image_pred_loss,
                        reward_pred_loss,
                        transition_loss,
                        entropy_loss,
                        pred_discount_loss,
                    ) = world_model_loss_rt(
                        world_model,
                        image_shape,
                        image_dist,
                        reward_dist,
                        {
                            key: value[np.arange(batch_indices.shape[1]),
                                       batch_indices].permute(1, 0, 2).reshape(
                                           -1, value.shape[-1])
                            for key, value in prior.items()
                        },
                        {
                            key: value[np.arange(batch_indices.shape[1]),
                                       batch_indices].permute(1, 0, 2).reshape(
                                           -1, value.shape[-1])
                            for key, value in post.items()
                        },
                        prior_dist,
                        post_dist,
                        pred_discount_dist,
                        obs,
                        rewards,
                        terminals,
                        **world_model_loss_kwargs,
                    )

                    batch_start = np.random.randint(
                        0,
                        low_level_actions.shape[1] - batch_len,
                        size=(low_level_actions.shape[0]),
                    )
                    batch_indices = np.linspace(
                        batch_start,
                        batch_start + batch_len,
                        batch_len,
                        endpoint=False,
                    ).astype(int)
                    primitive_loss = criterion(
                        action_preds[np.arange(batch_indices.shape[1]),
                                     batch_indices].permute(1, 0, 2).reshape(
                                         -1, action_preds.shape[-1]),
                        low_level_actions[:, 1:]
                        [np.arange(batch_indices.shape[1]),
                         batch_indices].permute(1, 0, 2).reshape(
                             -1, action_preds.shape[-1]),
                    )
                    total_primitive_loss += primitive_loss.item()
                    total_world_model_loss += world_model_loss.item()
                    total_div_loss += div.item()
                    total_image_pred_loss += image_pred_loss.item()
                    total_transition_loss += transition_loss.item()
                    total_entropy_loss += entropy_loss.item()
                    total_pred_discount_loss += pred_discount_loss.item()
                    total_reward_pred_loss += reward_pred_loss.item()

                    if loss_to_use == "wm":
                        loss = world_model_loss
                    elif loss_to_use == "primitive":
                        loss = primitive_loss
                    else:
                        loss = world_model_loss + primitive_loss
                    total_train_steps += 1

                update_network(world_model, optimizer, loss, gradient_clip,
                               scaler)
                scaler.update()
            eval_statistics["train/primitive_loss"] = (total_primitive_loss /
                                                       total_train_steps)
            eval_statistics["train/world_model_loss"] = (
                total_world_model_loss / total_train_steps)
            eval_statistics["train/image_pred_loss"] = (total_image_pred_loss /
                                                        total_train_steps)
            eval_statistics["train/transition_loss"] = (total_transition_loss /
                                                        total_train_steps)
            eval_statistics["train/entropy_loss"] = (total_entropy_loss /
                                                     total_train_steps)
            eval_statistics["train/pred_discount_loss"] = (
                total_pred_discount_loss / total_train_steps)
            eval_statistics["train/reward_pred_loss"] = (
                total_reward_pred_loss / total_train_steps)
            latest_state_dict = world_model.state_dict().copy()
            with torch.no_grad():
                total_primitive_loss = 0
                total_world_model_loss = 0
                total_div_loss = 0
                total_image_pred_loss = 0
                total_transition_loss = 0
                total_entropy_loss = 0
                total_pred_discount_loss = 0
                total_reward_pred_loss = 0
                total_loss = 0
                total_test_steps = 0
                for data in test_dataloader:
                    with torch.cuda.amp.autocast():
                        (
                            high_level_actions,
                            obs,
                            rewards,
                            terminals,
                        ), low_level_actions = data
                        obs = obs.to(ptu.device).float()
                        low_level_actions = low_level_actions.to(
                            ptu.device).float()
                        high_level_actions = high_level_actions.to(
                            ptu.device).float()
                        rewards = rewards.to(ptu.device).float()
                        terminals = terminals.to(ptu.device).float()
                        assert all(terminals[:, -1] == 1)
                        rt_idxs = np.arange(
                            num_low_level_actions_per_primitive,
                            obs.shape[1],
                            num_low_level_actions_per_primitive,
                        )
                        rt_idxs = np.concatenate(
                            [[0], rt_idxs]
                        )  # reset obs, effect of first primitive, second primitive, so on

                        batch_start = np.random.randint(0,
                                                        obs.shape[1] -
                                                        batch_len,
                                                        size=(obs.shape[0]))
                        batch_indices = np.linspace(
                            batch_start,
                            batch_start + batch_len,
                            batch_len,
                            endpoint=False,
                        ).astype(int)
                        (
                            post,
                            prior,
                            post_dist,
                            prior_dist,
                            image_dist,
                            reward_dist,
                            pred_discount_dist,
                            _,
                            action_preds,
                        ) = world_model(
                            obs,
                            (high_level_actions, low_level_actions),
                            use_network_action=False,
                            batch_indices=batch_indices,
                            rt_idxs=rt_idxs,
                        )
                        obs = world_model.flatten_obs(
                            obs[np.arange(batch_indices.shape[1]),
                                batch_indices].permute(1, 0, 2),
                            (int(np.prod(image_shape)), ),
                        )
                        rewards = rewards.reshape(-1, rewards.shape[-1])
                        terminals = terminals.reshape(-1, terminals.shape[-1])
                        (
                            world_model_loss,
                            div,
                            image_pred_loss,
                            reward_pred_loss,
                            transition_loss,
                            entropy_loss,
                            pred_discount_loss,
                        ) = world_model_loss_rt(
                            world_model,
                            image_shape,
                            image_dist,
                            reward_dist,
                            {
                                key: value[np.arange(batch_indices.shape[1]),
                                           batch_indices].permute(
                                               1, 0, 2).reshape(
                                                   -1, value.shape[-1])
                                for key, value in prior.items()
                            },
                            {
                                key: value[np.arange(batch_indices.shape[1]),
                                           batch_indices].permute(
                                               1, 0, 2).reshape(
                                                   -1, value.shape[-1])
                                for key, value in post.items()
                            },
                            prior_dist,
                            post_dist,
                            pred_discount_dist,
                            obs,
                            rewards,
                            terminals,
                            **world_model_loss_kwargs,
                        )

                        batch_start = np.random.randint(
                            0,
                            low_level_actions.shape[1] - batch_len,
                            size=(low_level_actions.shape[0]),
                        )
                        batch_indices = np.linspace(
                            batch_start,
                            batch_start + batch_len,
                            batch_len,
                            endpoint=False,
                        ).astype(int)
                        primitive_loss = criterion(
                            action_preds[np.arange(batch_indices.shape[1]),
                                         batch_indices].permute(
                                             1, 0, 2).reshape(
                                                 -1, action_preds.shape[-1]),
                            low_level_actions[:, 1:]
                            [np.arange(batch_indices.shape[1]),
                             batch_indices].permute(1, 0, 2).reshape(
                                 -1, action_preds.shape[-1]),
                        )
                        total_primitive_loss += primitive_loss.item()
                        total_world_model_loss += world_model_loss.item()
                        total_div_loss += div.item()
                        total_image_pred_loss += image_pred_loss.item()
                        total_transition_loss += transition_loss.item()
                        total_entropy_loss += entropy_loss.item()
                        total_pred_discount_loss += pred_discount_loss.item()
                        total_reward_pred_loss += reward_pred_loss.item()
                        total_loss += world_model_loss.item(
                        ) + primitive_loss.item()
                        total_test_steps += 1
                eval_statistics["test/primitive_loss"] = (
                    total_primitive_loss / total_test_steps)
                eval_statistics["test/world_model_loss"] = (
                    total_world_model_loss / total_test_steps)
                eval_statistics["test/image_pred_loss"] = (
                    total_image_pred_loss / total_test_steps)
                eval_statistics["test/transition_loss"] = (
                    total_transition_loss / total_test_steps)
                eval_statistics["test/entropy_loss"] = (total_entropy_loss /
                                                        total_test_steps)
                eval_statistics["test/pred_discount_loss"] = (
                    total_pred_discount_loss / total_test_steps)
                eval_statistics["test/reward_pred_loss"] = (
                    total_reward_pred_loss / total_test_steps)
                if (total_loss / total_test_steps) <= best_test_loss:
                    best_test_loss = total_loss / total_test_steps
                    os.makedirs(logdir + "/models/", exist_ok=True)
                    best_wm_state_dict = world_model.state_dict().copy()
                    torch.save(
                        best_wm_state_dict,
                        logdir + "/models/world_model.pt",
                    )
                if i % variant["plotting_period"] == 0:
                    print("Best test loss", best_test_loss)
                    world_model.load_state_dict(best_wm_state_dict)
                    visualize_wm(
                        env,
                        world_model,
                        train_dataset.outputs,
                        train_dataset.inputs[1],
                        test_dataset.outputs,
                        test_dataset.inputs[1],
                        logdir,
                        max_path_length,
                        low_level_primitives,
                        num_low_level_actions_per_primitive,
                        primitive_model=primitive_model,
                    )
                    world_model.load_state_dict(latest_state_dict)
                logger.record_dict(eval_statistics, prefix="")
                logger.dump_tabular(with_prefix=False, with_timestamp=False)

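    # Branch: behavior-clone one MLP primitive decoder per high-level primitive
    # on top of a pre-trained world model loaded from variant["world_model_path"];
    # only the primitive models are optimized here.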
    elif clone_primitives_separately:
        world_model.load_state_dict(torch.load(variant["world_model_path"]))
        criterion = nn.MSELoss()
        primitives = []
        for p in range(env.num_primitives):
            arguments_size = train_datasets[p].inputs[0].shape[-1]
            m = Mlp(
                hidden_sizes=variant["mlp_hidden_sizes"],
                output_size=low_level_action_dim,
                input_size=world_model.feature_size + arguments_size,
                hidden_activation=torch.nn.functional.relu,
            ).to(ptu.device)
            if variant.get("primitives_path", None):
                m.load_state_dict(
                    torch.load(variant["primitives_path"] +
                               "primitive_model_{}.pt".format(p)))
            primitives.append(m)

        optimizers = [
            optim.Adam(p.parameters(), **optimizer_kwargs) for p in primitives
        ]
        best_test_losses = [np.inf for _ in primitives]
        for i in tqdm(range(num_epochs)):
            if i % variant["plotting_period"] == 0:
                visualize_rollout(
                    env,
                    None,
                    None,
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=True,
                    forcing="none",
                    tag="none",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
                    primitive_model=primitives,
                    use_separate_primitives=True,
                )
                visualize_rollout(
                    env,
                    None,
                    None,
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=True,
                    forcing="teacher",
                    tag="none",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
                    primitive_model=primitives,
                    use_separate_primitives=True,
                )
                visualize_rollout(
                    env,
                    None,
                    None,
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=True,
                    forcing="self",
                    tag="none",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
                    primitive_model=primitives,
                    use_separate_primitives=True,
                )
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            for p, (
                    train_dataloader,
                    test_dataloader,
                    primitive_model,
                    optimizer,
            ) in enumerate(
                    zip(train_dataloaders, test_dataloaders, primitives,
                        optimizers)):
                total_loss = 0
                total_train_steps = 0
                for data in train_dataloader:
                    with torch.cuda.amp.autocast():
                        (arguments, obs), actions = data
                        obs = obs.to(ptu.device).float()
                        actions = actions.to(ptu.device).float()
                        arguments = arguments.to(ptu.device).float()
                        action_preds = world_model(
                            obs,
                            (arguments, actions),
                            primitive_model,
                            use_network_action=False,
                        )[-1]
                        loss = criterion(action_preds, actions)
                        total_loss += loss.item()
                        total_train_steps += 1

                    update_network(primitive_model, optimizer, loss,
                                   gradient_clip, scaler)
                    scaler.update()
                eval_statistics["train/primitive_loss {}".format(p)] = (
                    total_loss / total_train_steps)
                with torch.no_grad():
                    total_loss = 0
                    total_test_steps = 0
                    for data in test_dataloader:
                        with torch.cuda.amp.autocast():
                            (high_level_actions, obs), actions = data
                            obs = obs.to(ptu.device).float()
                            actions = actions.to(ptu.device).float()
                            high_level_actions = high_level_actions.to(
                                ptu.device).float()
                            action_preds = world_model(
                                obs,
                                (high_level_actions, actions),
                                primitive_model,
                                use_network_action=False,
                            )[-1]
                            loss = criterion(action_preds, actions)
                            total_loss += loss.item()
                            total_test_steps += 1
                    eval_statistics["test/primitive_loss {}".format(p)] = (
                        total_loss / total_test_steps)
                    if (total_loss / total_test_steps) <= best_test_losses[p]:
                        best_test_losses[p] = total_loss / total_test_steps
                        os.makedirs(logdir + "/models/", exist_ok=True)
                        torch.save(
                            primitive_model.state_dict(),
                            logdir + "/models/primitive_model_{}.pt".format(p),
                        )
            logger.record_dict(eval_statistics, prefix="")
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        visualize_rollout(
            env,
            None,
            None,
            world_model,
            logdir,
            max_path_length,
            use_env=True,
            forcing="none",
            tag="none",
            low_level_primitives=low_level_primitives,
            num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
            primitive_model=primitives,
            use_separate_primitives=True,
        )

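    # Branch: behavior-clone a single shared primitive model conditioned on the
    # high-level action (primitive one-hot + arguments, plus one extra scalar
    # input), again keeping the pre-trained world model fixed.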
    elif clone_primitives:
        world_model.load_state_dict(torch.load(variant["world_model_path"]))
        criterion = nn.MSELoss()
        primitive_model = Mlp(
            hidden_sizes=variant["mlp_hidden_sizes"],
            output_size=low_level_action_dim,
            input_size=world_model.feature_size +
            env.action_space.low.shape[0] + 1,
            hidden_activation=torch.nn.functional.relu,
        ).to(ptu.device)
        optimizer = optim.Adam(
            primitive_model.parameters(),
            **optimizer_kwargs,
        )
        best_test_loss = np.inf
        for i in tqdm(range(num_epochs)):
            if i % variant["plotting_period"] == 0:
                visualize_rollout(
                    env,
                    None,
                    None,
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=True,
                    forcing="none",
                    tag="none",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
                    primitive_model=primitive_model,
                )
                visualize_rollout(
                    env,
                    train_dataset.outputs,
                    train_dataset.inputs[1],
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=False,
                    forcing="teacher",
                    tag="train",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=num_low_level_actions_per_primitive - 1,
                )
                visualize_rollout(
                    env,
                    test_dataset.outputs,
                    test_dataset.inputs[1],
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=False,
                    forcing="teacher",
                    tag="test",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=num_low_level_actions_per_primitive - 1,
                )
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            total_loss = 0
            total_train_steps = 0
            for data in train_dataloader:
                with torch.cuda.amp.autocast():
                    (high_level_actions, obs), actions = data
                    obs = obs.to(ptu.device).float()
                    actions = actions.to(ptu.device).float()
                    high_level_actions = high_level_actions.to(
                        ptu.device).float()
                    action_preds = world_model(
                        obs,
                        (high_level_actions, actions),
                        primitive_model,
                        use_network_action=False,
                    )[-1]
                    loss = criterion(action_preds, actions)
                    total_loss += loss.item()
                    total_train_steps += 1

                update_network(primitive_model, optimizer, loss, gradient_clip,
                               scaler)
                scaler.update()
            eval_statistics[
                "train/primitive_loss"] = total_loss / total_train_steps
            with torch.no_grad():
                total_loss = 0
                total_test_steps = 0
                for data in test_dataloader:
                    with torch.cuda.amp.autocast():
                        (high_level_actions, obs), actions = data
                        obs = obs.to(ptu.device).float()
                        actions = actions.to(ptu.device).float()
                        high_level_actions = high_level_actions.to(
                            ptu.device).float()
                        action_preds = world_model(
                            obs,
                            (high_level_actions, actions),
                            primitive_model,
                            use_network_action=False,
                        )[-1]
                        loss = criterion(action_preds, actions)
                        total_loss += loss.item()
                        total_test_steps += 1
                eval_statistics[
                    "test/primitive_loss"] = total_loss / total_test_steps
                if (total_loss / total_test_steps) <= best_test_loss:
                    best_test_loss = total_loss / total_test_steps
                    os.makedirs(logdir + "/models/", exist_ok=True)
                    torch.save(
                        primitive_model.state_dict(),
                        logdir + "/models/primitive_model.pt",
                    )
                logger.record_dict(eval_statistics, prefix="")
                logger.dump_tabular(with_prefix=False, with_timestamp=False)
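    # Branch: no primitive cloning -- train the world model itself from scratch
    # on (action, observation) sequences and checkpoint the best test loss.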
    else:
        world_model = WorldModel(**world_model_kwargs).to(ptu.device)
        optimizer = optim.Adam(
            world_model.parameters(),
            **optimizer_kwargs,
        )
        best_test_loss = np.inf
        for i in tqdm(range(num_epochs)):
            if i % variant["plotting_period"] == 0:
                visualize_wm(
                    env,
                    world_model,
                    train_dataset.inputs,
                    train_dataset.outputs,
                    test_dataset.inputs,
                    test_dataset.outputs,
                    logdir,
                    max_path_length,
                    low_level_primitives,
                    num_low_level_actions_per_primitive,
                )
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            total_wm_loss = 0
            total_div_loss = 0
            total_image_pred_loss = 0
            total_transition_loss = 0
            total_entropy_loss = 0
            total_train_steps = 0
            for data in train_dataloader:
                with torch.cuda.amp.autocast():
                    actions, obs = data
                    obs = obs.to(ptu.device).float()
                    actions = actions.to(ptu.device).float()
                    post, prior, post_dist, prior_dist, image_dist = world_model(
                        obs, actions)[:5]
                    obs = world_model.flatten_obs(
                        obs.permute(1, 0, 2), (int(np.prod(image_shape)),)
                    )
                    (
                        world_model_loss,
                        div,
                        image_pred_loss,
                        transition_loss,
                        entropy_loss,
                    ) = compute_world_model_loss(
                        world_model,
                        image_shape,
                        image_dist,
                        prior,
                        post,
                        prior_dist,
                        post_dist,
                        obs,
                        **world_model_loss_kwargs,
                    )
                    total_wm_loss += world_model_loss.item()
                    total_div_loss += div.item()
                    total_image_pred_loss += image_pred_loss.item()
                    total_transition_loss += transition_loss.item()
                    total_entropy_loss += entropy_loss.item()
                    total_train_steps += 1

                update_network(world_model, optimizer, world_model_loss,
                               gradient_clip, scaler)
                scaler.update()
            eval_statistics[
                "train/wm_loss"] = total_wm_loss / total_train_steps
            eval_statistics[
                "train/div_loss"] = total_div_loss / total_train_steps
            eval_statistics["train/image_pred_loss"] = (total_image_pred_loss /
                                                        total_train_steps)
            eval_statistics["train/transition_loss"] = (total_transition_loss /
                                                        total_train_steps)
            eval_statistics["train/entropy_loss"] = (total_entropy_loss /
                                                     total_train_steps)
            with torch.no_grad():
                total_wm_loss = 0
                total_div_loss = 0
                total_image_pred_loss = 0
                total_transition_loss = 0
                total_entropy_loss = 0
                total_train_steps = 0
                total_test_steps = 0
                for data in test_dataloader:
                    with torch.cuda.amp.autocast():
                        actions, obs = data
                        obs = obs.to(ptu.device).float()
                        actions = actions.to(ptu.device).float()
                        post, prior, post_dist, prior_dist, image_dist = world_model(
                            obs, actions)[:5]
                        obs = world_model.flatten_obs(
                            obs.permute(1, 0, 2), (int(np.prod(image_shape)),)
                        )
                        (
                            world_model_loss,
                            div,
                            image_pred_loss,
                            transition_loss,
                            entropy_loss,
                        ) = compute_world_model_loss(
                            world_model,
                            image_shape,
                            image_dist,
                            prior,
                            post,
                            prior_dist,
                            post_dist,
                            obs,
                            **world_model_loss_kwargs,
                        )
                        total_wm_loss += world_model_loss.item()
                        total_div_loss += div.item()
                        total_image_pred_loss += image_pred_loss.item()
                        total_transition_loss += transition_loss.item()
                        total_entropy_loss += entropy_loss.item()
                        total_test_steps += 1
                eval_statistics[
                    "test/wm_loss"] = total_wm_loss / total_test_steps
                eval_statistics[
                    "test/div_loss"] = total_div_loss / total_test_steps
                eval_statistics["test/image_pred_loss"] = (
                    total_image_pred_loss / total_test_steps)
                eval_statistics["test/transition_loss"] = (
                    total_transition_loss / total_test_steps)
                eval_statistics["test/entropy_loss"] = (total_entropy_loss /
                                                        total_test_steps)
                if (total_wm_loss / total_test_steps) <= best_test_loss:
                    best_test_loss = total_wm_loss / total_test_steps
                    os.makedirs(logdir + "/models/", exist_ok=True)
                    torch.save(
                        world_model.state_dict(),
                        logdir + "/models/world_model.pt",
                    )
                logger.record_dict(eval_statistics, prefix="")
                logger.dump_tabular(with_prefix=False, with_timestamp=False)

        world_model.load_state_dict(
            torch.load(logdir + "/models/world_model.pt"))
        visualize_wm(
            env,
            world_model,
            train_dataset,
            test_dataset,
            logdir,
            max_path_length,
            low_level_primitives,
            num_low_level_actions_per_primitive,
        )
Example #15
0
def test_trainer_save_load():
    env_kwargs = dict(
        use_image_obs=True,
        imwidth=64,
        imheight=64,
        reward_type="sparse",
        usage_kwargs=dict(
            max_path_length=5,
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
        ),
        action_space_kwargs=dict(
            control_mode="primitives",
            action_scale=1,
            camera_settings={
                "distance": 0.38227044687537043,
                "lookat": [0.21052547, 0.32329237, 0.587819],
                "azimuth": 141.328125,
                "elevation": -53.203125160653144,
            },
        ),
    )
    actor_kwargs = dict(
        discrete_continuous_dist=True,
        init_std=0.0,
        num_layers=4,
        min_std=0.1,
        dist="tanh_normal_dreamer_v1",
    )
    vf_kwargs = dict(num_layers=3, )
    model_kwargs = dict(
        model_hidden_size=400,
        stochastic_state_size=50,
        deterministic_state_size=200,
        rssm_hidden_size=200,
        reward_num_layers=2,
        pred_discount_num_layers=3,
        gru_layer_norm=True,
        std_act="sigmoid2",
        use_prior_instead_of_posterior=False,
    )
    trainer_kwargs = dict(
        adam_eps=1e-5,
        discount=0.8,
        lam=0.95,
        forward_kl=False,
        free_nats=1.0,
        pred_discount_loss_scale=10.0,
        kl_loss_scale=0.0,
        transition_loss_scale=0.8,
        actor_lr=8e-5,
        vf_lr=8e-5,
        world_model_lr=3e-4,
        reward_loss_scale=2.0,
        use_pred_discount=True,
        policy_gradient_loss_scale=1.0,
        actor_entropy_loss_schedule="1e-4",
        target_update_period=100,
        detach_rewards=False,
        imagination_horizon=5,
    )
    env_suite = "metaworld"
    env_name = "disassemble-v2"
    eval_envs = [make_env(env_suite, env_name, env_kwargs) for _ in range(1)]
    eval_env = DummyVecEnv(eval_envs, )

    discrete_continuous_dist = True
    continuous_action_dim = eval_envs[0].max_arg_len
    discrete_action_dim = eval_envs[0].num_primitives
    if not discrete_continuous_dist:
        continuous_action_dim = continuous_action_dim + discrete_action_dim
        discrete_action_dim = 0
    action_dim = continuous_action_dim + discrete_action_dim

    world_model = WorldModel(
        action_dim,
        image_shape=eval_envs[0].image_shape,
        **model_kwargs,
    )
    actor = ActorModel(
        model_kwargs["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **actor_kwargs,
    )

    vf = Mlp(
        hidden_sizes=[model_kwargs["model_hidden_size"]] *
        vf_kwargs["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )
    target_vf = Mlp(
        hidden_sizes=[model_kwargs["model_hidden_size"]] *
        vf_kwargs["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )

    trainer = DreamerV2Trainer(
        actor,
        vf,
        target_vf,
        world_model,
        eval_envs[0].image_shape,
        **trainer_kwargs,
    )

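    # Round-trip the trainer through save()/load() in a temporary directory to
    # check that its state can be serialized and restored.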
    with tempfile.TemporaryDirectory() as tmpdirname:
        trainer.save(tmpdirname, "trainer.pkl")
        trainer = DreamerV2Trainer(
            actor,
            vf,
            target_vf,
            world_model,
            eval_envs[0].image_shape,
            **trainer_kwargs,
        )
        new_trainer = trainer.load(tmpdirname, "trainer.pkl")
Example #16
0
def experiment(variant):
    import os
    import os.path as osp

    os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"

    import torch
    import torch.nn as nn

    import rlkit.torch.pytorch_util as ptu
    from rlkit.core import logger
    from rlkit.envs.primitives_make_env import make_env
    from rlkit.envs.wrappers.mujoco_vec_wrappers import (
        DummyVecEnv,
        StableBaselinesVecEnv,
    )
    from rlkit.torch.model_based.dreamer.actor_models import ActorModel
    from rlkit.torch.model_based.dreamer.dreamer_policy import (
        ActionSpaceSamplePolicy,
        DreamerLowLevelRAPSPolicy,
    )
    from rlkit.torch.model_based.dreamer.dreamer_v2 import DreamerV2LowLevelRAPSTrainer
    from rlkit.torch.model_based.dreamer.episode_replay_buffer import (
        EpisodeReplayBufferLowLevelRAPS, )
    from rlkit.torch.model_based.dreamer.mlp import Mlp
    from rlkit.torch.model_based.dreamer.path_collector import VecMdpPathCollector
    from rlkit.torch.model_based.dreamer.rollout_functions import (
        vec_rollout_low_level_raps, )
    from rlkit.torch.model_based.dreamer.visualization import (
        post_epoch_visualize_func,
        visualize_primitive_unsubsampled_rollout,
    )
    from rlkit.torch.model_based.dreamer.world_models import LowlevelRAPSWorldModel
    from rlkit.torch.model_based.rl_algorithm import TorchBatchRLAlgorithm

    env_suite = variant.get("env_suite", "kitchen")
    env_kwargs = variant["env_kwargs"]
    num_expl_envs = variant["num_expl_envs"]
    num_low_level_actions_per_primitive = variant[
        "num_low_level_actions_per_primitive"]
    low_level_action_dim = variant["low_level_action_dim"]

    print("MAKING ENVS")
    env_name = variant["env_name"]
    make_env_lambda = lambda: make_env(env_suite, env_name, env_kwargs)

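    # Use a fork-based vectorized environment when collecting with several
    # exploration workers; otherwise fall back to a single-process DummyVecEnv.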
    if num_expl_envs > 1:
        env_fns = [make_env_lambda for _ in range(num_expl_envs)]
        expl_env = StableBaselinesVecEnv(
            env_fns=env_fns,
            start_method="fork",
            reload_state_args=(
                num_expl_envs,
                make_env,
                (env_suite, env_name, env_kwargs),
            ),
        )
    else:
        expl_envs = [make_env_lambda()]
        expl_env = DummyVecEnv(expl_envs,
                               pass_render_kwargs=variant.get(
                                   "pass_render_kwargs", False))
    eval_envs = [make_env_lambda() for _ in range(1)]
    eval_env = DummyVecEnv(eval_envs,
                           pass_render_kwargs=variant.get(
                               "pass_render_kwargs", False))

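    # RAPS action layout: a discrete one-hot over primitives plus a continuous
    # argument vector. Without a mixed discrete/continuous distribution the
    # whole action is treated as continuous.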
    discrete_continuous_dist = variant["actor_kwargs"][
        "discrete_continuous_dist"]
    num_primitives = eval_envs[0].num_primitives
    continuous_action_dim = eval_envs[0].max_arg_len
    discrete_action_dim = num_primitives
    if not discrete_continuous_dist:
        continuous_action_dim = continuous_action_dim + discrete_action_dim
        discrete_action_dim = 0
    action_dim = continuous_action_dim + discrete_action_dim
    obs_dim = expl_env.observation_space.low.size

    primitive_model = Mlp(
        output_size=variant["low_level_action_dim"],
        input_size=variant["model_kwargs"]["stochastic_state_size"] +
        variant["model_kwargs"]["deterministic_state_size"] +
        eval_env.envs[0].action_space.low.shape[0] + 1,
        hidden_activation=nn.ReLU,
        num_embeddings=eval_envs[0].num_primitives,
        embedding_dim=eval_envs[0].num_primitives,
        embedding_slice=eval_envs[0].num_primitives,
        **variant["primitive_model_kwargs"],
    )
    world_model = LowlevelRAPSWorldModel(
        low_level_action_dim,
        image_shape=eval_envs[0].image_shape,
        primitive_model=primitive_model,
        **variant["model_kwargs"],
    )
    actor = ActorModel(
        variant["model_kwargs"]["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **variant["actor_kwargs"],
    )
    vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]] *
        variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )
    target_vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]] *
        variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )

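    # Optionally warm-start the actor, value functions, and world model from a
    # checkpoint directory.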
    if variant.get("models_path", None) is not None:
        filename = variant["models_path"]
        actor.load_state_dict(torch.load(osp.join(filename, "actor.ptc")))
        vf.load_state_dict(torch.load(osp.join(filename, "vf.ptc")))
        target_vf.load_state_dict(
            torch.load(osp.join(filename, "target_vf.ptc")))
        world_model.load_state_dict(
            torch.load(osp.join(filename, "world_model.ptc")))
        print("LOADED MODELS")

    expl_policy = DreamerLowLevelRAPSPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
        low_level_action_dim=low_level_action_dim,
        exploration=True,
        expl_amount=variant.get("expl_amount", 0.3),
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )
    eval_policy = DreamerLowLevelRAPSPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
        low_level_action_dim=low_level_action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )

    initial_data_collection_policy = ActionSpaceSamplePolicy(expl_env)

    rollout_function_kwargs = dict(
        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
        low_level_action_dim=low_level_action_dim,
        num_primitives=num_primitives,
    )

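    # Path collectors roll out the policies in the vectorized envs with the
    # low-level RAPS rollout function.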
    expl_path_collector = VecMdpPathCollector(
        expl_env,
        expl_policy,
        save_env_in_snapshot=False,
        rollout_fn=vec_rollout_low_level_raps,
        rollout_function_kwargs=rollout_function_kwargs,
    )

    eval_path_collector = VecMdpPathCollector(
        eval_env,
        eval_policy,
        save_env_in_snapshot=False,
        rollout_fn=vec_rollout_low_level_raps,
        rollout_function_kwargs=rollout_function_kwargs,
    )

    replay_buffer = EpisodeReplayBufferLowLevelRAPS(
        num_expl_envs, obs_dim, action_dim, **variant["replay_buffer_kwargs"])
    filename = variant.get("replay_buffer_path", None)
    if filename is not None:
        replay_buffer.load_buffer(filename, eval_env.envs[0].num_primitives)
    eval_filename = variant.get("eval_buffer_path", None)
    if eval_filename is not None:
        eval_buffer = EpisodeReplayBufferLowLevelRAPS(
            1000,
            expl_env,
            variant["algorithm_kwargs"]["max_path_length"],
            num_low_level_actions_per_primitive,
            obs_dim,
            action_dim,
            low_level_action_dim,
            replace=False,
        )
        eval_buffer.load_buffer(eval_filename, eval_env.envs[0].num_primitives)
    else:
        eval_buffer = None

    trainer = DreamerV2LowLevelRAPSTrainer(
        actor,
        vf,
        target_vf,
        world_model,
        eval_envs[0].image_shape,
        **variant["trainer_kwargs"],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        pretrain_policy=initial_data_collection_policy,
        **variant["algorithm_kwargs"],
        eval_buffer=eval_buffer,
    )
    algorithm.low_level_primitives = True
    if variant.get("generate_video", False):
        post_epoch_visualize_func(algorithm, 0)
    elif variant.get("unsubsampled_rollout", False):
        visualize_primitive_unsubsampled_rollout(
            make_env_lambda(),
            make_env_lambda(),
            make_env_lambda(),
            logger.get_snapshot_dir(),
            algorithm.max_path_length,
            num_low_level_actions_per_primitive,
            policy=eval_policy,
            img_size=64,
            num_rollouts=4,
        )
    else:
        if variant.get("save_video", False):
            algorithm.post_epoch_funcs.append(post_epoch_visualize_func)
        print("TRAINING")
        algorithm.to(ptu.device)
        algorithm.train()
        if variant.get("save_video", False):
            post_epoch_visualize_func(algorithm, -1)
Example #17
0
        },
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            use_image_obs=True,
            max_path_length=5,
            unflatten_images=False,
        ),
        image_kwargs=dict(imwidth=64, imheight=64),
        collect_primitives_info=True,
        include_phase_variable=True,
    )
    env_suite = "metaworld"
    env_name = "reach-v2"

    env = make_env(env_suite, env_name, env_kwargs)

    file_path = osp.join("data", args.logdir, "test.avi")

    a1 = env.action_space.sample()
    a1[0] = 100
    obs = env.reset()
    o, r, d, i = env.step(
        a1,
        render_every_step=True,
        render_mode="rgb_array",
        render_im_shape=(480, 480),
    )

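    # With collect_primitives_info=True the info dict returned by env.step
    # contains the low-level actions and robot states executed while the
    # primitive ran.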
    true_actions1 = np.array(i["actions"])
    true_states1 = np.array(i["robot-states"])
Example #18
0
def experiment(variant):
    gym.logger.set_level(40)
    work_dir = rlkit_logger.get_snapshot_dir()
    args = parse_args()
    seed = int(variant["seed"])
    utils.set_seed_everywhere(seed)
    os.makedirs(work_dir, exist_ok=True)
    agent_kwargs = variant["agent_kwargs"]
    data_augs = agent_kwargs["data_augs"]
    encoder_type = agent_kwargs["encoder_type"]
    discrete_continuous_dist = agent_kwargs["discrete_continuous_dist"]

    env_suite = variant["env_suite"]
    env_name = variant["env_name"]
    env_kwargs = variant["env_kwargs"]
    pre_transform_image_size = variant["pre_transform_image_size"]
    image_size = variant["image_size"]
    frame_stack = variant["frame_stack"]
    batch_size = variant["batch_size"]
    replay_buffer_capacity = variant["replay_buffer_capacity"]
    num_train_steps = variant["num_train_steps"]
    num_eval_episodes = variant["num_eval_episodes"]
    eval_freq = variant["eval_freq"]
    action_repeat = variant["action_repeat"]
    init_steps = variant["init_steps"]
    log_interval = variant["log_interval"]
    use_raw_actions = variant["use_raw_actions"]
    pre_transform_image_size = (
        pre_transform_image_size if "crop" in data_augs else image_size
    )

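    # RAD-style augmentation: render frames larger than image_size so a random
    # crop or translate can be applied later.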
    if data_augs == "crop":
        pre_transform_image_size = 100
    elif data_augs == "translate":
        pre_transform_image_size = 100
        image_size = 108

    if env_suite == 'kitchen':
        env_kwargs['imwidth'] = pre_transform_image_size
        env_kwargs['imheight'] = pre_transform_image_size
    else:
        env_kwargs['image_kwargs']['imwidth'] = pre_transform_image_size
        env_kwargs['image_kwargs']['imheight'] = pre_transform_image_size

    expl_env = primitives_make_env.make_env(env_suite, env_name, env_kwargs)
    eval_env = primitives_make_env.make_env(env_suite, env_name, env_kwargs)
    # stack several consecutive frames together
    if encoder_type == "pixel":
        expl_env = utils.FrameStack(expl_env, k=frame_stack)
        eval_env = utils.FrameStack(eval_env, k=frame_stack)

    # make directory
    ts = time.gmtime()
    ts = time.strftime("%m-%d", ts)
    exp_name = (
        env_name
        + "-"
        + ts
        + "-im"
        + str(image_size)
        + "-b"
        + str(batch_size)
        + "-s"
        + str(seed)
        + "-"
        + encoder_type
    )
    work_dir = work_dir + "/" + exp_name

    utils.make_dir(work_dir)
    video_dir = utils.make_dir(os.path.join(work_dir, "video"))
    model_dir = utils.make_dir(os.path.join(work_dir, "model"))
    buffer_dir = utils.make_dir(os.path.join(work_dir, "buffer"))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(work_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if use_raw_actions:
        continuous_action_dim = expl_env.action_space.low.size
        discrete_action_dim = 0
    else:
        num_primitives = expl_env.num_primitives
        max_arg_len = expl_env.max_arg_len
        if discrete_continuous_dist:
            continuous_action_dim = max_arg_len
            discrete_action_dim = num_primitives
        else:
            continuous_action_dim = max_arg_len + num_primitives
            discrete_action_dim = 0

    if encoder_type == "pixel":
        obs_shape = (3 * frame_stack, image_size, image_size)
        pre_aug_obs_shape = (
            3 * frame_stack,
            pre_transform_image_size,
            pre_transform_image_size,
        )
    else:
        obs_shape = expl_env.observation_space.shape
        pre_aug_obs_shape = obs_shape

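    # The buffer stores observations at the pre-augmentation resolution;
    # image_size and pre_image_size presumably let it apply crops when sampling
    # batches.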
    replay_buffer = utils.ReplayBuffer(
        obs_shape=pre_aug_obs_shape,
        action_size=continuous_action_dim + discrete_action_dim,
        capacity=replay_buffer_capacity,
        batch_size=batch_size,
        device=device,
        image_size=image_size,
        pre_image_size=pre_transform_image_size,
    )

    agent = make_agent(
        obs_shape=obs_shape,
        continuous_action_dim=continuous_action_dim,
        discrete_action_dim=discrete_action_dim,
        args=args,
        device=device,
        agent_kwargs=agent_kwargs,
    )

    L = Logger(work_dir, use_tb=args.save_tb)

    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    epoch_start_time = time.time()
    train_expl_st = time.time()
    total_train_expl_time = 0
    all_infos = []
    ep_infos = []
    num_train_calls = 0
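    # Main interaction loop: evaluate every eval_freq steps, log per-episode
    # statistics, and start agent updates once init_steps of random exploration
    # have been collected.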
    for step in range(num_train_steps):
        # evaluate agent periodically

        if step % eval_freq == 0:
            total_train_expl_time += time.time() - train_expl_st
            L.log("eval/episode", episode, step)
            evaluate(
                eval_env,
                agent,
                video,
                num_eval_episodes,
                L,
                step,
                encoder_type,
                data_augs,
                image_size,
                pre_transform_image_size,
                env_name,
                action_repeat,
                work_dir,
                seed,
            )
            if args.save_model:
                agent.save_curl(model_dir, step)
            if args.save_buffer:
                replay_buffer.save(buffer_dir)
            train_expl_st = time.time()
        if done:
            if step > 0:
                if step % log_interval == 0:
                    L.log("train/duration", time.time() - epoch_start_time, step)
                    L.dump(step)
            if step % log_interval == 0:
                L.log("train/episode_reward", episode_reward, step)
            obs = expl_env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1
            if step % log_interval == 0:
                all_infos.append(ep_infos)

                L.log("train/episode", episode, step)
                statistics = compute_path_info(all_infos)

                rlkit_logger.record_dict(statistics, prefix="exploration/")
                rlkit_logger.record_tabular(
                    "time/epoch (s)", time.time() - epoch_start_time
                )
                rlkit_logger.record_tabular("time/total (s)", time.time() - start_time)
                rlkit_logger.record_tabular("time/training and exploration (s)", total_train_expl_time)
                rlkit_logger.record_tabular("trainer/num train calls", num_train_calls)
                rlkit_logger.record_tabular("exploration/num steps total", step)
                rlkit_logger.record_tabular("Epoch", step // log_interval)
                rlkit_logger.dump_tabular(with_prefix=False, with_timestamp=False)
                all_infos = []
                epoch_start_time = time.time()
            ep_infos = []


        # sample action for data collection
        if step < init_steps:
            action = expl_env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs / 255.0)

        # run training update
        if step >= init_steps:
            num_updates = 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)
                num_train_calls += 1

        next_obs, reward, done, info = expl_env.step(action)
        ep_infos.append(info)
        # allow infinite bootstrap
        done_bool = (
            0 if episode_step + 1 == expl_env._max_episode_steps else float(done)
        )
        episode_reward += reward
        replay_buffer.add(obs, action, reward, next_obs, done_bool)

        obs = next_obs
        episode_step += 1