import numpy as np

from rlkit.envs.primitives_make_env import make_env  # assumed import path, matching the experiment scripts


def test_run_hinge_success():
    env_suite = "kitchen"
    env_name = "hinge_cabinet"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(env_suite, env_name, env_kwargs)
    env.reset()
    ctr = 0
    max_path_length = 5
    for _ in range(max_path_length):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            env.reset()
            a[env.get_idx_from_primitive_name("lift")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["lift"]] = 1
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("angled_x_y_grasp")] = 1
            a[
                env.num_primitives
                + np.array(env.primitive_name_to_action_idx["angled_x_y_grasp"])
            ] = np.array([-np.pi / 6, -0.3, 1.4, 0])
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("move_delta_ee_pose")] = 1
            a[
                env.num_primitives
                + np.array(env.primitive_name_to_action_idx["move_delta_ee_pose"])
            ] = np.array([0.5, -1, 0])
        if ctr % max_path_length == 3:
            a[env.get_idx_from_primitive_name("rotate_about_x_axis")] = 1
            a[
                env.num_primitives
                + np.array(env.primitive_name_to_action_idx["rotate_about_x_axis"])
            ] = np.array([1])
        if ctr % max_path_length == 4:
            a[env.get_idx_from_primitive_name("rotate_about_x_axis")] = 1
            a[
                env.num_primitives
                + np.array(env.primitive_name_to_action_idx["rotate_about_x_axis"])
            ] = np.array([0])
        o, r, d, i = env.step(a / 1.4)  # divide out the 1.4 action scale
        ctr += 1
    assert r == 1.0
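# The tests in this file all build actions the same way. The hypothetical
# helper below (make_primitive_action is not part of the codebase) makes the
# layout explicit: the first num_primitives entries form a one-hot primitive
# selector, and the remaining entries hold the flattened arguments of every
# primitive, indexed via primitive_name_to_action_idx.
def make_primitive_action(env, name, args):
    a = np.zeros(env.action_space.low.size)
    # Select the primitive via its one-hot index.
    a[env.get_idx_from_primitive_name(name)] = 1
    # Write this primitive's arguments into its slice of the argument block.
    a[env.num_primitives + np.array(env.primitive_name_to_action_idx[name])] = args
    return a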
def test_light_switch_success():
    env_suite = "kitchen"
    env_name = "light_switch"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(env_suite, env_name, env_kwargs)
    env.reset()
    max_path_length = 5
    ctr = 0
    for i in range(max_path_length):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            a[env.get_idx_from_primitive_name("close_gripper")] = 1
            a[env.num_primitives + np.array(env.primitive_name_to_action_idx["lift"])] = 1
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("lift")] = 1
            a[env.num_primitives + np.array(env.primitive_name_to_action_idx["lift"])] = 0.6
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("move_right")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["move_right"]] = 0.45
        if ctr % max_path_length == 3:
            a[env.get_idx_from_primitive_name("move_forward")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["move_forward"]] = 1.25
        if ctr % max_path_length == 4:
            a[env.get_idx_from_primitive_name("move_left")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["move_left"]] = 0.45
        o, r, d, _ = env.step(a / 1.4)
        ctr += 1
    assert r == 1.0
def reset(self):
    # Recreate the environment on each reset, cycling through the task list.
    if hasattr(self, "env"):
        del self.env
        gc.collect()
    self.idx = self.num_resets % self.num_multitask_envs
    env_name = self.env_names[self.idx]
    self.env = primitives_make_env.make_env(self.env_suite, env_name, self.env_kwargs)
    o = self.env.reset()
    self.num_resets += 1
    o = np.concatenate((o, self.get_one_hot(self.idx)))
    return o
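# reset() above appends a one-hot task encoding via get_one_hot. A minimal
# sketch of that helper, assuming the wrapper stores num_multitask_envs (the
# exact implementation in the codebase may differ):
def get_one_hot(self, idx):
    one_hot = np.zeros(self.num_multitask_envs)
    one_hot[idx] = 1
    return one_hot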
def test_run_kettle_success():
    env_suite = "kitchen"
    env_name = "kettle"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(env_suite, env_name, env_kwargs)
    env.reset()
    ctr = 0
    max_path_length = 5
    for i in range(max_path_length):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            env.reset()
            a[env.get_idx_from_primitive_name("drop")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["drop"]] = 0.5
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("angled_x_y_grasp")] = 1
            a[
                env.num_primitives
                + np.array(env.primitive_name_to_action_idx["angled_x_y_grasp"])
            ] = np.array([0, 0.15, 0.7, 1])
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("move_delta_ee_pose")] = 1
            a[
                env.num_primitives
                + np.array(env.primitive_name_to_action_idx["move_delta_ee_pose"])
            ] = np.array([0.25, 1.0, 0.25])
        if ctr % max_path_length == 3:
            a[env.get_idx_from_primitive_name("drop")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["drop"]] = 0.25
        if ctr % max_path_length == 4:
            a[env.get_idx_from_primitive_name("open_gripper")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["open_gripper"]] = 1
        o, r, d, _ = env.step(a / 1.4)
        ctr += 1
    assert r == 1.0
def test_top_burner_success():
    env_suite = "kitchen"
    env_name = "top_left_burner"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(env_suite, env_name, env_kwargs)
    env.reset()
    ctr = 0
    max_path_length = 3
    for i in range(max_path_length):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            env.reset()
            a[env.get_idx_from_primitive_name("lift")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["lift"]] = 0.6
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("angled_x_y_grasp")] = 1
            a[
                env.num_primitives
                + np.array(env.primitive_name_to_action_idx["angled_x_y_grasp"])
            ] = np.array([0, 0.5, 1, 1])
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("rotate_about_y_axis")] = 1
            a[
                env.num_primitives
                + env.primitive_name_to_action_idx["rotate_about_y_axis"]
            ] = -np.pi / 4
        o, r, d, _ = env.step(a / 1.4)
        ctr += 1
    assert r == 1.0
def test_run_microwave_success():
    env_suite = "kitchen"
    env_name = "microwave"
    env_kwargs = dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    )
    env = make_env(env_suite, env_name, env_kwargs)
    env.reset()
    ctr = 0
    max_path_length = 3
    for i in range(max_path_length):
        a = np.zeros(env.action_space.low.size)
        if ctr % max_path_length == 0:
            env.reset()
            a[env.get_idx_from_primitive_name("drop")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["drop"]] = 0.55
        if ctr % max_path_length == 1:
            a[env.get_idx_from_primitive_name("angled_x_y_grasp")] = 1
            a[
                env.num_primitives
                + np.array(env.primitive_name_to_action_idx["angled_x_y_grasp"])
            ] = np.array([-np.pi / 6, -0.3, 0.95, 1])
        if ctr % max_path_length == 2:
            a[env.get_idx_from_primitive_name("move_backward")] = 1
            a[env.num_primitives + env.primitive_name_to_action_idx["move_backward"]] = 0.6
        o, r, d, _ = env.step(a / 1.4)
        ctr += 1
    assert r == 1.0
def test_dummy_vec_env_save_load():
    env_kwargs = dict(
        use_image_obs=True,
        imwidth=64,
        imheight=64,
        reward_type="sparse",
        usage_kwargs=dict(
            max_path_length=5,
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
        ),
        action_space_kwargs=dict(
            control_mode="primitives",
            action_scale=1,
            camera_settings={
                "distance": 0.38227044687537043,
                "lookat": [0.21052547, 0.32329237, 0.587819],
                "azimuth": 141.328125,
                "elevation": -53.203125160653144,
            },
        ),
    )
    env_suite = "metaworld"
    env_name = "disassemble-v2"
    make_env_lambda = lambda: make_env(env_suite, env_name, env_kwargs)
    n_envs = 2
    envs = [make_env_lambda() for _ in range(n_envs)]
    env = DummyVecEnv(envs)
    with tempfile.TemporaryDirectory() as tmpdirname:
        env.save(tmpdirname, "env.pkl")
        env = DummyVecEnv(envs[0:1])
        new_env = env.load(tmpdirname, "env.pkl")
        assert new_env.n_envs == n_envs
def test_path_collector_save_load():
    env_kwargs = dict(
        use_image_obs=True,
        imwidth=64,
        imheight=64,
        reward_type="sparse",
        usage_kwargs=dict(
            max_path_length=5,
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
        ),
        action_space_kwargs=dict(
            control_mode="primitives",
            action_scale=1,
            camera_settings={
                "distance": 0.38227044687537043,
                "lookat": [0.21052547, 0.32329237, 0.587819],
                "azimuth": 141.328125,
                "elevation": -53.203125160653144,
            },
        ),
    )
    actor_kwargs = dict(
        discrete_continuous_dist=True,
        init_std=0.0,
        num_layers=4,
        min_std=0.1,
        dist="tanh_normal_dreamer_v1",
    )
    model_kwargs = dict(
        model_hidden_size=400,
        stochastic_state_size=50,
        deterministic_state_size=200,
        rssm_hidden_size=200,
        reward_num_layers=2,
        pred_discount_num_layers=3,
        gru_layer_norm=True,
        std_act="sigmoid2",
        use_prior_instead_of_posterior=False,
    )
    env_suite = "metaworld"
    env_name = "disassemble-v2"
    eval_envs = [make_env(env_suite, env_name, env_kwargs) for _ in range(1)]
    eval_env = DummyVecEnv(eval_envs)
    discrete_continuous_dist = True
    continuous_action_dim = eval_envs[0].max_arg_len
    discrete_action_dim = eval_envs[0].num_primitives
    if not discrete_continuous_dist:
        continuous_action_dim = continuous_action_dim + discrete_action_dim
        discrete_action_dim = 0
    action_dim = continuous_action_dim + discrete_action_dim
    obs_dim = eval_env.observation_space.low.size
    world_model = WorldModel(
        action_dim,
        image_shape=eval_envs[0].image_shape,
        **model_kwargs,
    )
    actor = ActorModel(
        model_kwargs["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **actor_kwargs,
    )
    eval_policy = DreamerPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )
    eval_path_collector = VecMdpPathCollector(
        eval_env,
        eval_policy,
        save_env_in_snapshot=False,
    )
    with tempfile.TemporaryDirectory() as tmpdirname:
        eval_path_collector.save(tmpdirname, "path_collector.pkl")
        eval_path_collector = VecMdpPathCollector(
            eval_env,
            eval_policy,
            save_env_in_snapshot=False,
        )
        new_path_collector = eval_path_collector.load(tmpdirname, "path_collector.pkl")
def experiment(variant):
    import os
    import os.path as osp

    os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"
    import torch
    import torch.nn as nn

    import rlkit.envs.primitives_make_env as primitives_make_env
    import rlkit.torch.pytorch_util as ptu
    from rlkit.envs.wrappers.mujoco_vec_wrappers import (
        DummyVecEnv,
        StableBaselinesVecEnv,
    )
    from rlkit.torch.model_based.dreamer.actor_models import ActorModel
    from rlkit.torch.model_based.dreamer.dreamer_policy import (
        ActionSpaceSamplePolicy,
        DreamerPolicy,
    )
    from rlkit.torch.model_based.dreamer.dreamer_v2 import DreamerV2Trainer
    from rlkit.torch.model_based.dreamer.episode_replay_buffer import (
        EpisodeReplayBuffer,
        EpisodeReplayBufferLowLevelRAPS,
    )
    from rlkit.torch.model_based.dreamer.mlp import Mlp
    from rlkit.torch.model_based.dreamer.path_collector import VecMdpPathCollector
    from rlkit.torch.model_based.dreamer.visualization import post_epoch_visualize_func
    from rlkit.torch.model_based.dreamer.world_models import WorldModel
    from rlkit.torch.model_based.rl_algorithm import TorchBatchRLAlgorithm

    env_suite = variant.get("env_suite", "kitchen")
    env_name = variant["env_name"]
    env_kwargs = variant["env_kwargs"]
    use_raw_actions = variant["use_raw_actions"]
    num_expl_envs = variant["num_expl_envs"]
    if num_expl_envs > 1:
        env_fns = [
            lambda: primitives_make_env.make_env(env_suite, env_name, env_kwargs)
            for _ in range(num_expl_envs)
        ]
        expl_env = StableBaselinesVecEnv(
            env_fns=env_fns,
            start_method="fork",
            reload_state_args=(
                num_expl_envs,
                primitives_make_env.make_env,
                (env_suite, env_name, env_kwargs),
            ),
        )
    else:
        expl_envs = [primitives_make_env.make_env(env_suite, env_name, env_kwargs)]
        expl_env = DummyVecEnv(
            expl_envs, pass_render_kwargs=variant.get("pass_render_kwargs", False)
        )
    eval_envs = [
        primitives_make_env.make_env(env_suite, env_name, env_kwargs)
        for _ in range(1)
    ]
    eval_env = DummyVecEnv(
        eval_envs, pass_render_kwargs=variant.get("pass_render_kwargs", False)
    )
    if use_raw_actions:
        discrete_continuous_dist = False
        continuous_action_dim = eval_env.action_space.low.size
        discrete_action_dim = 0
        use_batch_length = True
        action_dim = continuous_action_dim
    else:
        discrete_continuous_dist = variant["actor_kwargs"]["discrete_continuous_dist"]
        continuous_action_dim = eval_envs[0].max_arg_len
        discrete_action_dim = eval_envs[0].num_primitives
        if not discrete_continuous_dist:
            continuous_action_dim = continuous_action_dim + discrete_action_dim
            discrete_action_dim = 0
        action_dim = continuous_action_dim + discrete_action_dim
        use_batch_length = False
    obs_dim = expl_env.observation_space.low.size
    world_model = WorldModel(
        action_dim,
        image_shape=eval_envs[0].image_shape,
        **variant["model_kwargs"],
    )
    actor = ActorModel(
        variant["model_kwargs"]["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **variant["actor_kwargs"],
    )
    vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]]
        * variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )
    target_vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]]
        * variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )
    if variant.get("models_path", None) is not None:
        filename = variant["models_path"]
        actor.load_state_dict(torch.load(osp.join(filename, "actor.ptc")))
        vf.load_state_dict(torch.load(osp.join(filename, "vf.ptc")))
        target_vf.load_state_dict(torch.load(osp.join(filename, "target_vf.ptc")))
        world_model.load_state_dict(torch.load(osp.join(filename, "world_model.ptc")))
        print("LOADED MODELS")
    expl_policy = DreamerPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        exploration=True,
        expl_amount=variant.get("expl_amount", 0.3),
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )
    eval_policy = DreamerPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )
    rand_policy = ActionSpaceSamplePolicy(expl_env)
    expl_path_collector = VecMdpPathCollector(
        expl_env,
        expl_policy,
        save_env_in_snapshot=False,
    )
    eval_path_collector = VecMdpPathCollector(
        eval_env,
        eval_policy,
        save_env_in_snapshot=False,
    )
    variant["replay_buffer_kwargs"]["use_batch_length"] = use_batch_length
    replay_buffer = EpisodeReplayBuffer(
        num_expl_envs,
        obs_dim,
        action_dim,
        **variant["replay_buffer_kwargs"],
    )
    eval_filename = variant.get("eval_buffer_path", None)
    if eval_filename is not None:
        eval_buffer = EpisodeReplayBufferLowLevelRAPS(
            1000,
            expl_env,
            variant["algorithm_kwargs"]["max_path_length"],
            10,
            obs_dim,
            action_dim,
            9,
            replace=False,
        )
        eval_buffer.load_buffer(eval_filename, eval_env.envs[0].num_primitives)
    else:
        eval_buffer = None
    trainer = DreamerV2Trainer(
        actor,
        vf,
        target_vf,
        world_model,
        eval_envs[0].image_shape,
        **variant["trainer_kwargs"],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        pretrain_policy=rand_policy,
        eval_buffer=eval_buffer,
        **variant["algorithm_kwargs"],
    )
    algorithm.low_level_primitives = False
    if variant.get("generate_video", False):
        post_epoch_visualize_func(algorithm, 0)
    else:
        if variant.get("save_video", False):
            algorithm.post_epoch_funcs.append(post_epoch_visualize_func)
        print("TRAINING")
        algorithm.to(ptu.device)
        algorithm.train()
        if variant.get("save_video", False):
            post_epoch_visualize_func(algorithm, -1)
def run_trained_policy(path):
    ptu.set_gpu_mode(True)
    variant = json.load(open(osp.join(path, "variant.json"), "r"))
    set_seed(variant["seed"])
    variant = preprocess_variant_llraps(variant)
    env_suite = variant.get("env_suite", "kitchen")
    env_kwargs = variant["env_kwargs"]
    num_low_level_actions_per_primitive = variant["num_low_level_actions_per_primitive"]
    low_level_action_dim = variant["low_level_action_dim"]
    env_name = variant["env_name"]
    make_env_lambda = lambda: make_env(env_suite, env_name, env_kwargs)
    eval_envs = [make_env_lambda() for _ in range(1)]
    eval_env = DummyVecEnv(
        eval_envs, pass_render_kwargs=variant.get("pass_render_kwargs", False)
    )
    discrete_continuous_dist = variant["actor_kwargs"]["discrete_continuous_dist"]
    num_primitives = eval_envs[0].num_primitives
    continuous_action_dim = eval_envs[0].max_arg_len
    discrete_action_dim = num_primitives
    if not discrete_continuous_dist:
        continuous_action_dim = continuous_action_dim + discrete_action_dim
        discrete_action_dim = 0
    action_dim = continuous_action_dim + discrete_action_dim
    obs_dim = eval_env.observation_space.low.size
    primitive_model = Mlp(
        output_size=variant["low_level_action_dim"],
        input_size=variant["model_kwargs"]["stochastic_state_size"]
        + variant["model_kwargs"]["deterministic_state_size"]
        + eval_env.envs[0].action_space.low.shape[0]
        + 1,
        hidden_activation=nn.ReLU,
        num_embeddings=eval_envs[0].num_primitives,
        embedding_dim=eval_envs[0].num_primitives,
        embedding_slice=eval_envs[0].num_primitives,
        **variant["primitive_model_kwargs"],
    )
    world_model = LowlevelRAPSWorldModel(
        low_level_action_dim,
        image_shape=eval_envs[0].image_shape,
        primitive_model=primitive_model,
        **variant["model_kwargs"],
    )
    actor = ActorModel(
        variant["model_kwargs"]["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **variant["actor_kwargs"],
    )
    actor.load_state_dict(torch.load(osp.join(path, "actor.ptc")))
    world_model.load_state_dict(torch.load(osp.join(path, "world_model.ptc")))
    actor.to(ptu.device)
    world_model.to(ptu.device)
    eval_policy = DreamerLowLevelRAPSPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
        low_level_action_dim=low_level_action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            for step in range(0, variant["algorithm_kwargs"]["max_path_length"] + 1):
                if step == 0:
                    observation = eval_env.envs[0].reset()
                    eval_policy.reset(observation.reshape(1, -1))
                    policy_o = (None, observation.reshape(1, -1))
                    reward = 0
                else:
                    high_level_action, _ = eval_policy.get_action(policy_o)
                    observation, reward, done, info = eval_env.envs[0].step(
                        high_level_action[0]
                    )
                    low_level_obs = np.expand_dims(np.array(info["low_level_obs"]), 0)
                    low_level_action = np.expand_dims(
                        np.array(info["low_level_action"]), 0
                    )
                    policy_o = (low_level_action, low_level_obs)
    return reward
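# Illustrative usage of run_trained_policy, assuming a run directory produced
# by the training script (the path below is a placeholder):
if __name__ == "__main__":
    final_reward = run_trained_policy("/path/to/run_dir")
    print("final reward:", final_reward)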
        image_kwargs=dict(imwidth=64, imheight=64),
        collect_primitives_info=True,
        include_phase_variable=True,
        render_intermediate_obs_to_info=args.collect_data_fn
        != "collect_primitive_cloning_data",
        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
    )
    datafile = "wm_H_{}_T_{}_E_{}_P_{}_raps_ll_hl_even_rt_{}".format(
        args.max_path_length,
        num_trajs,
        args.num_envs,
        num_low_level_actions_per_primitive,
        env_name,
    )
    env_fns = [
        lambda: make_env(env_suite, env_name, env_kwargs)
        for _ in range(args.num_envs)
    ]
    env = StableBaselinesVecEnv(env_fns=env_fns, start_method="fork")
    if args.collect_data_fn == "collect_world_model_data":
        data = collect_world_model_data(
            env,
            num_trajs * args.num_envs,
            args.num_envs,
            args.max_path_length,
        )
        save_data(data, datafile)
    elif args.collect_data_fn == "collect_world_model_data_low_level_primitives":
def experiment(variant):
    import os

    import rlkit.envs.primitives_make_env as primitives_make_env

    os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"
    import torch

    import rlkit.torch.pytorch_util as ptu
    from rlkit.envs.wrappers.mujoco_vec_wrappers import (
        DummyVecEnv,
        StableBaselinesVecEnv,
    )
    from rlkit.torch.model_based.dreamer.actor_models import ActorModel
    from rlkit.torch.model_based.dreamer.dreamer_policy import (
        ActionSpaceSamplePolicy,
        DreamerPolicy,
    )
    from rlkit.torch.model_based.dreamer.episode_replay_buffer import (
        EpisodeReplayBuffer,
    )
    from rlkit.torch.model_based.dreamer.mlp import Mlp
    from rlkit.torch.model_based.dreamer.path_collector import VecMdpPathCollector
    from rlkit.torch.model_based.dreamer.visualization import video_post_epoch_func
    from rlkit.torch.model_based.dreamer.world_models import WorldModel
    from rlkit.torch.model_based.plan2explore.latent_space_models import (
        OneStepEnsembleModel,
    )
    from rlkit.torch.model_based.plan2explore.plan2explore import Plan2ExploreTrainer
    from rlkit.torch.model_based.rl_algorithm import TorchBatchRLAlgorithm

    env_suite = variant.get("env_suite", "kitchen")
    env_name = variant["env_name"]
    env_kwargs = variant["env_kwargs"]
    use_raw_actions = variant["use_raw_actions"]
    num_expl_envs = variant["num_expl_envs"]
    actor_model_class_name = variant.get("actor_model_class", "actor_model")
    if num_expl_envs > 1:
        env_fns = [
            lambda: primitives_make_env.make_env(env_suite, env_name, env_kwargs)
            for _ in range(num_expl_envs)
        ]
        expl_env = StableBaselinesVecEnv(env_fns=env_fns, start_method="fork")
    else:
        expl_envs = [primitives_make_env.make_env(env_suite, env_name, env_kwargs)]
        expl_env = DummyVecEnv(
            expl_envs, pass_render_kwargs=variant.get("pass_render_kwargs", False)
        )
    eval_envs = [
        primitives_make_env.make_env(env_suite, env_name, env_kwargs)
        for _ in range(1)
    ]
    eval_env = DummyVecEnv(
        eval_envs, pass_render_kwargs=variant.get("pass_render_kwargs", False)
    )
    if use_raw_actions:
        discrete_continuous_dist = False
        continuous_action_dim = eval_env.action_space.low.size
        discrete_action_dim = 0
        use_batch_length = True
        action_dim = continuous_action_dim
    else:
        discrete_continuous_dist = variant["actor_kwargs"]["discrete_continuous_dist"]
        continuous_action_dim = eval_envs[0].max_arg_len
        discrete_action_dim = eval_envs[0].num_primitives
        if not discrete_continuous_dist:
            continuous_action_dim = continuous_action_dim + discrete_action_dim
            discrete_action_dim = 0
        action_dim = continuous_action_dim + discrete_action_dim
        use_batch_length = False
    world_model_class = WorldModel
    obs_dim = expl_env.observation_space.low.size
    actor_model_class = ActorModel
    if variant.get("load_from_path", False):
        data = torch.load(variant["models_path"])
        actor = data["trainer/actor"]
        vf = data["trainer/vf"]
        target_vf = data["trainer/target_vf"]
        world_model = data["trainer/world_model"]
    else:
        world_model = world_model_class(
            action_dim,
            image_shape=eval_envs[0].image_shape,
            **variant["model_kwargs"],
            env=eval_envs[0].env,
        )
        actor = actor_model_class(
            variant["model_kwargs"]["model_hidden_size"],
            world_model.feature_size,
            hidden_activation=torch.nn.functional.elu,
            discrete_action_dim=discrete_action_dim,
            continuous_action_dim=continuous_action_dim,
            env=eval_envs[0].env,
            **variant["actor_kwargs"],
        )
        vf = Mlp(
            hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]]
            * variant["vf_kwargs"]["num_layers"],
            output_size=1,
            input_size=world_model.feature_size,
            hidden_activation=torch.nn.functional.elu,
        )
        target_vf = Mlp(
            hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]]
            * variant["vf_kwargs"]["num_layers"],
            output_size=1,
            input_size=world_model.feature_size,
            hidden_activation=torch.nn.functional.elu,
        )
    one_step_ensemble = OneStepEnsembleModel(
        action_dim=action_dim,
        embedding_size=variant["model_kwargs"]["embedding_size"],
        deterministic_state_size=variant["model_kwargs"]["deterministic_state_size"],
        stochastic_state_size=variant["model_kwargs"]["stochastic_state_size"],
        **variant["one_step_ensemble_kwargs"],
    )
    exploration_actor = actor_model_class(
        variant["model_kwargs"]["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=torch.nn.functional.elu,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        env=eval_envs[0],
        **variant["actor_kwargs"],
    )
    exploration_vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]]
        * variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=torch.nn.functional.elu,
    )
    exploration_target_vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]]
        * variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=torch.nn.functional.elu,
    )
    if variant.get("expl_with_exploration_actor", True):
        expl_actor = exploration_actor
    else:
        expl_actor = actor
    expl_policy = DreamerPolicy(
        world_model,
        expl_actor,
        obs_dim,
        action_dim,
        exploration=True,
        expl_amount=variant.get("expl_amount", 0.3),
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=variant["actor_kwargs"]["discrete_continuous_dist"],
    )
    if variant.get("eval_with_exploration_actor", False):
        eval_actor = exploration_actor
    else:
        eval_actor = actor
    eval_policy = DreamerPolicy(
        world_model,
        eval_actor,
        obs_dim,
        action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=variant["actor_kwargs"]["discrete_continuous_dist"],
    )
    rand_policy = ActionSpaceSamplePolicy(expl_env)
    expl_path_collector = VecMdpPathCollector(
        expl_env,
        expl_policy,
        save_env_in_snapshot=False,
    )
    eval_path_collector = VecMdpPathCollector(
        eval_env,
        eval_policy,
        save_env_in_snapshot=False,
    )
    replay_buffer = EpisodeReplayBuffer(
        variant["replay_buffer_size"],
        expl_env,
        variant["algorithm_kwargs"]["max_path_length"] + 1,
        obs_dim,
        action_dim,
        replace=False,
        use_batch_length=use_batch_length,
    )
    trainer = Plan2ExploreTrainer(
        eval_env,
        actor,
        vf,
        target_vf,
        world_model,
        eval_envs[0].image_shape,
        exploration_actor,
        exploration_vf,
        exploration_target_vf,
        one_step_ensemble,
        **variant["trainer_kwargs"],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        pretrain_policy=rand_policy,
        **variant["algorithm_kwargs"],
    )
    algorithm.post_epoch_funcs.append(video_post_epoch_func)
    algorithm.to(ptu.device)
    algorithm.train()
    video_post_epoch_func(algorithm, -1)
def test_run_assembly_success():
    env_suite = "metaworld"
    env_name = "assembly-v2"
    env_kwargs = dict(
        use_image_obs=True,
        imwidth=64,
        imheight=64,
        reward_type="sparse",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(
            control_mode="primitives",
            action_scale=1,
            camera_settings={
                "distance": 0.38227044687537043,
                "lookat": [0.21052547, 0.32329237, 0.587819],
                "azimuth": 141.328125,
                "elevation": -53.203125160653144,
            },
        ),
    )
    render_mode = "rgb_array"
    render_im_shape = (64, 64)
    render_every_step = True
    env = make_env(env_suite, env_name, env_kwargs)
    o = env.reset()
    for i in range(5):
        a = env.action_space.sample()
        a = np.zeros_like(a)
        if i % 5 == 0:
            primitive = "top_x_y_grasp"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives + np.array(env.primitive_name_to_action_idx[primitive])] = [
                0.25,
                0.0,
                -0.6,
                1,
            ]
        elif i % 5 == 1:
            primitive = "lift"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives + np.array(env.primitive_name_to_action_idx[primitive])] = 0.4
        elif i % 5 == 2:
            primitive = "move_forward"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives + np.array(env.primitive_name_to_action_idx[primitive])] = 0.45
        elif i % 5 == 3:
            primitive = "move_right"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives + np.array(env.primitive_name_to_action_idx[primitive])] = 0.05
        elif i % 5 == 4:  # was a duplicate `i % 5 == 3`, which made this branch unreachable
            primitive = "open_gripper"
            a[env.get_idx_from_primitive_name(primitive)] = 1
            a[env.num_primitives + np.array(env.primitive_name_to_action_idx[primitive])] = 1
        o, r, d, info = env.step(
            a,
            render_every_step=render_every_step,
            render_mode=render_mode,
            render_im_shape=render_im_shape,
        )
    assert r == 1.0
def experiment(variant):
    import os
    from collections import OrderedDict

    import numpy as np
    import torch
    from torch import nn, optim
    from tqdm import tqdm

    import rlkit.torch.pytorch_util as ptu
    from rlkit.core import logger
    from rlkit.envs.primitives_make_env import make_env
    from rlkit.torch.model_based.dreamer.mlp import Mlp, MlpResidual
    from rlkit.torch.model_based.dreamer.train_world_model import (
        compute_world_model_loss,
        get_dataloader,
        get_dataloader_rt,
        get_dataloader_separately,
        update_network,
        visualize_rollout,
        world_model_loss_rt,
    )
    from rlkit.torch.model_based.dreamer.world_models import (
        LowlevelRAPSWorldModel,
        WorldModel,
    )

    env_suite, env_name, env_kwargs = (
        variant["env_suite"],
        variant["env_name"],
        variant["env_kwargs"],
    )
    max_path_length = variant["env_kwargs"]["max_path_length"]
    low_level_primitives = variant["low_level_primitives"]
    num_low_level_actions_per_primitive = variant["num_low_level_actions_per_primitive"]
    low_level_action_dim = variant["low_level_action_dim"]
    dataloader_kwargs = variant["dataloader_kwargs"]
    env = make_env(env_suite, env_name, env_kwargs)
    world_model_kwargs = variant["model_kwargs"]
    optimizer_kwargs = variant["optimizer_kwargs"]
    gradient_clip = variant["gradient_clip"]
    if low_level_primitives:
        world_model_kwargs["action_dim"] = low_level_action_dim
    else:
        world_model_kwargs["action_dim"] = env.action_space.low.shape[0]
    image_shape = env.image_shape
    world_model_kwargs["image_shape"] = image_shape
    scaler = torch.cuda.amp.GradScaler()
    world_model_loss_kwargs = variant["world_model_loss_kwargs"]
    clone_primitives = variant["clone_primitives"]
    clone_primitives_separately = variant["clone_primitives_separately"]
    clone_primitives_and_train_world_model = variant.get(
        "clone_primitives_and_train_world_model", False
    )
    batch_len = variant.get("batch_len", 100)
    num_epochs = variant["num_epochs"]
    loss_to_use = variant.get("loss_to_use", "both")
    logdir = logger.get_snapshot_dir()

    if clone_primitives_separately:
        (
            train_dataloaders,
            test_dataloaders,
            train_datasets,
            test_datasets,
        ) = get_dataloader_separately(
            variant["datafile"],
            num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
            num_primitives=env.num_primitives,
            env=env,
            **dataloader_kwargs,
        )
    elif clone_primitives_and_train_world_model:
        print("LOADING DATA")
        (
            train_dataloader,
            test_dataloader,
            train_dataset,
            test_dataset,
        ) = get_dataloader_rt(
            variant["datafile"],
            max_path_length=max_path_length * num_low_level_actions_per_primitive + 1,
            **dataloader_kwargs,
        )
    elif low_level_primitives or clone_primitives:
        print("LOADING DATA")
        (
            train_dataloader,
            test_dataloader,
            train_dataset,
            test_dataset,
        ) = get_dataloader(
            variant["datafile"],
            max_path_length=max_path_length * num_low_level_actions_per_primitive + 1,
            **dataloader_kwargs,
        )
    else:
        train_dataloader, test_dataloader, train_dataset, test_dataset = get_dataloader(
            variant["datafile"],
            max_path_length=max_path_length + 1,
            **dataloader_kwargs,
        )

    if clone_primitives_and_train_world_model:
        if variant["mlp_act"] == "elu":
            mlp_act = nn.functional.elu
        elif variant["mlp_act"] == "relu":
            mlp_act = nn.functional.relu
        if variant["mlp_res"]:
            mlp_class = MlpResidual
        else:
            mlp_class = Mlp
        criterion = nn.MSELoss()
        primitive_model = mlp_class(
            hidden_sizes=variant["mlp_hidden_sizes"],
            output_size=low_level_action_dim,
            input_size=250 + env.action_space.low.shape[0] + 1,
            hidden_activation=mlp_act,
        ).to(ptu.device)
        world_model_class = LowlevelRAPSWorldModel
        world_model = world_model_class(
            primitive_model=primitive_model,
            **world_model_kwargs,
        ).to(ptu.device)
        optimizer = optim.Adam(
            world_model.parameters(),
            **optimizer_kwargs,
        )
        best_test_loss = np.inf
        for i in tqdm(range(num_epochs)):
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            total_primitive_loss = 0
            total_world_model_loss = 0
            total_div_loss = 0
            total_image_pred_loss = 0
            total_transition_loss = 0
            total_entropy_loss = 0
            total_pred_discount_loss = 0
            total_reward_pred_loss = 0
            total_train_steps = 0
            for data in train_dataloader:
                with torch.cuda.amp.autocast():
                    (
                        high_level_actions,
                        obs,
                        rewards,
                        terminals,
                    ), low_level_actions = data
                    obs = obs.to(ptu.device).float()
                    low_level_actions = low_level_actions.to(ptu.device).float()
                    high_level_actions = high_level_actions.to(ptu.device).float()
                    rewards = rewards.to(ptu.device).float()
                    terminals = terminals.to(ptu.device).float()
                    assert all(terminals[:, -1] == 1)
                    rt_idxs = np.arange(
                        num_low_level_actions_per_primitive,
                        obs.shape[1],
                        num_low_level_actions_per_primitive,
                    )
                    # reset obs, effect of first primitive, second primitive, and so on
                    rt_idxs = np.concatenate([[0], rt_idxs])
                    batch_start = np.random.randint(
                        0, obs.shape[1] - batch_len, size=(obs.shape[0])
                    )
                    batch_indices = np.linspace(
                        batch_start,
                        batch_start + batch_len,
                        batch_len,
                        endpoint=False,
                    ).astype(int)
                    (
                        post,
                        prior,
                        post_dist,
                        prior_dist,
                        image_dist,
                        reward_dist,
                        pred_discount_dist,
                        _,
                        action_preds,
                    ) = world_model(
                        obs,
                        (high_level_actions, low_level_actions),
                        use_network_action=False,
                        batch_indices=batch_indices,
                        rt_idxs=rt_idxs,
                    )
                    obs = world_model.flatten_obs(
                        obs[np.arange(batch_indices.shape[1]), batch_indices].permute(
                            1, 0, 2
                        ),
                        (int(np.prod(image_shape)),),
                    )
                    rewards = rewards.reshape(-1, rewards.shape[-1])
                    terminals = terminals.reshape(-1, terminals.shape[-1])
                    (
                        world_model_loss,
                        div,
                        image_pred_loss,
                        reward_pred_loss,
                        transition_loss,
                        entropy_loss,
                        pred_discount_loss,
                    ) = world_model_loss_rt(
                        world_model,
                        image_shape,
                        image_dist,
                        reward_dist,
                        {
                            key: value[np.arange(batch_indices.shape[1]), batch_indices]
                            .permute(1, 0, 2)
                            .reshape(-1, value.shape[-1])
                            for key, value in prior.items()
                        },
                        {
                            key: value[np.arange(batch_indices.shape[1]), batch_indices]
                            .permute(1, 0, 2)
                            .reshape(-1, value.shape[-1])
                            for key, value in post.items()
                        },
                        prior_dist,
                        post_dist,
                        pred_discount_dist,
                        obs,
                        rewards,
                        terminals,
                        **world_model_loss_kwargs,
                    )
                    batch_start = np.random.randint(
                        0,
                        low_level_actions.shape[1] - batch_len,
                        size=(low_level_actions.shape[0]),
                    )
                    batch_indices = np.linspace(
                        batch_start,
                        batch_start + batch_len,
                        batch_len,
                        endpoint=False,
                    ).astype(int)
                    primitive_loss = criterion(
                        action_preds[np.arange(batch_indices.shape[1]), batch_indices]
                        .permute(1, 0, 2)
                        .reshape(-1, action_preds.shape[-1]),
                        low_level_actions[:, 1:][
                            np.arange(batch_indices.shape[1]), batch_indices
                        ]
                        .permute(1, 0, 2)
                        .reshape(-1, action_preds.shape[-1]),
                    )
                    total_primitive_loss += primitive_loss.item()
                    total_world_model_loss += world_model_loss.item()
                    total_div_loss += div.item()
                    total_image_pred_loss += image_pred_loss.item()
                    total_transition_loss += transition_loss.item()
                    total_entropy_loss += entropy_loss.item()
                    total_pred_discount_loss += pred_discount_loss.item()
                    total_reward_pred_loss += reward_pred_loss.item()
                    if loss_to_use == "wm":
                        loss = world_model_loss
                    elif loss_to_use == "primitive":
                        loss = primitive_loss
                    else:
                        loss = world_model_loss + primitive_loss
                    total_train_steps += 1
                update_network(world_model, optimizer, loss, gradient_clip, scaler)
                scaler.update()
            eval_statistics["train/primitive_loss"] = (
                total_primitive_loss / total_train_steps
            )
            eval_statistics["train/world_model_loss"] = (
                total_world_model_loss / total_train_steps
            )
            eval_statistics["train/image_pred_loss"] = (
                total_image_pred_loss / total_train_steps
            )
            eval_statistics["train/transition_loss"] = (
                total_transition_loss / total_train_steps
            )
            eval_statistics["train/entropy_loss"] = (
                total_entropy_loss / total_train_steps
            )
            eval_statistics["train/pred_discount_loss"] = (
                total_pred_discount_loss / total_train_steps
            )
            eval_statistics["train/reward_pred_loss"] = (
                total_reward_pred_loss / total_train_steps
            )
            latest_state_dict = world_model.state_dict().copy()
            with torch.no_grad():
                total_primitive_loss = 0
                total_world_model_loss = 0
                total_div_loss = 0
                total_image_pred_loss = 0
                total_transition_loss = 0
                total_entropy_loss = 0
                total_pred_discount_loss = 0
                total_reward_pred_loss = 0
                total_loss = 0
                total_test_steps = 0
                for data in test_dataloader:
                    with torch.cuda.amp.autocast():
                        (
                            high_level_actions,
                            obs,
                            rewards,
                            terminals,
                        ), low_level_actions = data
                        obs = obs.to(ptu.device).float()
                        low_level_actions = low_level_actions.to(ptu.device).float()
                        high_level_actions = high_level_actions.to(ptu.device).float()
                        rewards = rewards.to(ptu.device).float()
                        terminals = terminals.to(ptu.device).float()
                        assert all(terminals[:, -1] == 1)
                        rt_idxs = np.arange(
                            num_low_level_actions_per_primitive,
                            obs.shape[1],
                            num_low_level_actions_per_primitive,
                        )
                        # reset obs, effect of first primitive, second primitive, and so on
                        rt_idxs = np.concatenate([[0], rt_idxs])
                        batch_start = np.random.randint(
                            0, obs.shape[1] - batch_len, size=(obs.shape[0])
                        )
                        batch_indices = np.linspace(
                            batch_start,
                            batch_start + batch_len,
                            batch_len,
                            endpoint=False,
                        ).astype(int)
                        (
                            post,
                            prior,
                            post_dist,
                            prior_dist,
                            image_dist,
                            reward_dist,
                            pred_discount_dist,
                            _,
                            action_preds,
                        ) = world_model(
                            obs,
                            (high_level_actions, low_level_actions),
                            use_network_action=False,
                            batch_indices=batch_indices,
                            rt_idxs=rt_idxs,
                        )
                        obs = world_model.flatten_obs(
                            obs[
                                np.arange(batch_indices.shape[1]), batch_indices
                            ].permute(1, 0, 2),
                            (int(np.prod(image_shape)),),
                        )
                        rewards = rewards.reshape(-1, rewards.shape[-1])
                        terminals = terminals.reshape(-1, terminals.shape[-1])
                        (
                            world_model_loss,
                            div,
                            image_pred_loss,
                            reward_pred_loss,
                            transition_loss,
                            entropy_loss,
                            pred_discount_loss,
                        ) = world_model_loss_rt(
                            world_model,
                            image_shape,
                            image_dist,
                            reward_dist,
                            {
                                key: value[
                                    np.arange(batch_indices.shape[1]), batch_indices
                                ]
                                .permute(1, 0, 2)
                                .reshape(-1, value.shape[-1])
                                for key, value in prior.items()
                            },
                            {
                                key: value[
                                    np.arange(batch_indices.shape[1]), batch_indices
                                ]
                                .permute(1, 0, 2)
                                .reshape(-1, value.shape[-1])
                                for key, value in post.items()
                            },
                            prior_dist,
                            post_dist,
                            pred_discount_dist,
                            obs,
                            rewards,
                            terminals,
                            **world_model_loss_kwargs,
                        )
                        batch_start = np.random.randint(
                            0,
                            low_level_actions.shape[1] - batch_len,
                            size=(low_level_actions.shape[0]),
                        )
                        batch_indices = np.linspace(
                            batch_start,
                            batch_start + batch_len,
                            batch_len,
                            endpoint=False,
                        ).astype(int)
                        primitive_loss = criterion(
                            action_preds[
                                np.arange(batch_indices.shape[1]), batch_indices
                            ]
                            .permute(1, 0, 2)
                            .reshape(-1, action_preds.shape[-1]),
                            low_level_actions[:, 1:][
                                np.arange(batch_indices.shape[1]), batch_indices
                            ]
                            .permute(1, 0, 2)
                            .reshape(-1, action_preds.shape[-1]),
                        )
                        total_primitive_loss += primitive_loss.item()
                        total_world_model_loss += world_model_loss.item()
                        total_div_loss += div.item()
                        total_image_pred_loss += image_pred_loss.item()
                        total_transition_loss += transition_loss.item()
                        total_entropy_loss += entropy_loss.item()
                        total_pred_discount_loss += pred_discount_loss.item()
                        total_reward_pred_loss += reward_pred_loss.item()
                        total_loss += world_model_loss.item() + primitive_loss.item()
                        total_test_steps += 1
                eval_statistics["test/primitive_loss"] = (
                    total_primitive_loss / total_test_steps
                )
                eval_statistics["test/world_model_loss"] = (
                    total_world_model_loss / total_test_steps
                )
                eval_statistics["test/image_pred_loss"] = (
                    total_image_pred_loss / total_test_steps
                )
                eval_statistics["test/transition_loss"] = (
                    total_transition_loss / total_test_steps
                )
                eval_statistics["test/entropy_loss"] = (
                    total_entropy_loss / total_test_steps
                )
                eval_statistics["test/pred_discount_loss"] = (
                    total_pred_discount_loss / total_test_steps
                )
                eval_statistics["test/reward_pred_loss"] = (
                    total_reward_pred_loss / total_test_steps
                )
                if (total_loss / total_test_steps) <= best_test_loss:
                    best_test_loss = total_loss / total_test_steps
                    os.makedirs(logdir + "/models/", exist_ok=True)
                    best_wm_state_dict = world_model.state_dict().copy()
                    torch.save(
                        best_wm_state_dict,
                        logdir + "/models/world_model.pt",
                    )
            if i % variant["plotting_period"] == 0:
                print("Best test loss", best_test_loss)
                world_model.load_state_dict(best_wm_state_dict)
                # visualize_wm is assumed to be defined in this module (not shown here).
                visualize_wm(
                    env,
                    world_model,
                    train_dataset.outputs,
                    train_dataset.inputs[1],
                    test_dataset.outputs,
                    test_dataset.inputs[1],
                    logdir,
                    max_path_length,
                    low_level_primitives,
                    num_low_level_actions_per_primitive,
                    primitive_model=primitive_model,
                )
                world_model.load_state_dict(latest_state_dict)
            logger.record_dict(eval_statistics, prefix="")
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
    elif clone_primitives_separately:
        # Assumed: the world model must be instantiated before loading pretrained
        # weights; the extracted source referenced `world_model` here without
        # constructing it. This mirrors the `else` branch below.
        world_model = WorldModel(**world_model_kwargs).to(ptu.device)
        world_model.load_state_dict(torch.load(variant["world_model_path"]))
        criterion = nn.MSELoss()
        primitives = []
        for p in range(env.num_primitives):
            arguments_size = train_datasets[p].inputs[0].shape[-1]
            m = Mlp(
                hidden_sizes=variant["mlp_hidden_sizes"],
                output_size=low_level_action_dim,
                input_size=world_model.feature_size + arguments_size,
                hidden_activation=torch.nn.functional.relu,
            ).to(ptu.device)
            if variant.get("primitives_path", None):
                m.load_state_dict(
                    torch.load(
                        variant["primitives_path"] + "primitive_model_{}.pt".format(p)
                    )
                )
            primitives.append(m)
        optimizers = [
            optim.Adam(p.parameters(), **optimizer_kwargs) for p in primitives
        ]
        # Track the best test loss per primitive across epochs; the original reset
        # this inside the epoch loop, which saved a checkpoint every epoch.
        best_test_losses = [np.inf for _ in range(env.num_primitives)]
        for i in tqdm(range(num_epochs)):
            if i % variant["plotting_period"] == 0:
                for forcing in ("none", "teacher", "self"):
                    visualize_rollout(
                        env,
                        None,
                        None,
                        world_model,
                        logdir,
                        max_path_length,
                        use_env=True,
                        forcing=forcing,
                        tag="none",
                        low_level_primitives=low_level_primitives,
                        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
                        primitive_model=primitives,
                        use_separate_primitives=True,
                    )
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            for p, (
                train_dataloader,
                test_dataloader,
                primitive_model,
                optimizer,
            ) in enumerate(
                zip(train_dataloaders, test_dataloaders, primitives, optimizers)
            ):
                total_loss = 0
                total_train_steps = 0
                for data in train_dataloader:
                    with torch.cuda.amp.autocast():
                        (arguments, obs), actions = data
                        obs = obs.to(ptu.device).float()
                        actions = actions.to(ptu.device).float()
                        arguments = arguments.to(ptu.device).float()
                        action_preds = world_model(
                            obs,
                            (arguments, actions),
                            primitive_model,
                            use_network_action=False,
                        )[-1]
                        loss = criterion(action_preds, actions)
                        total_loss += loss.item()
                        total_train_steps += 1
                    update_network(
                        primitive_model, optimizer, loss, gradient_clip, scaler
                    )
                    scaler.update()
                eval_statistics["train/primitive_loss {}".format(p)] = (
                    total_loss / total_train_steps
                )
                with torch.no_grad():
                    total_loss = 0
                    total_test_steps = 0
                    for data in test_dataloader:
                        with torch.cuda.amp.autocast():
                            (high_level_actions, obs), actions = data
                            obs = obs.to(ptu.device).float()
                            actions = actions.to(ptu.device).float()
                            high_level_actions = high_level_actions.to(ptu.device).float()
                            action_preds = world_model(
                                obs,
                                (high_level_actions, actions),
                                primitive_model,
                                use_network_action=False,
                            )[-1]
                            loss = criterion(action_preds, actions)
                            total_loss += loss.item()
                            total_test_steps += 1
                    eval_statistics["test/primitive_loss {}".format(p)] = (
                        total_loss / total_test_steps
                    )
                    if (total_loss / total_test_steps) <= best_test_losses[p]:
                        best_test_losses[p] = total_loss / total_test_steps
                        os.makedirs(logdir + "/models/", exist_ok=True)
                        torch.save(
                            primitive_model.state_dict(),
                            logdir + "/models/primitive_model_{}.pt".format(p),
                        )
            logger.record_dict(eval_statistics, prefix="")
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        visualize_rollout(
            env,
            None,
            None,
            world_model,
            logdir,
            max_path_length,
            use_env=True,
            forcing="none",
            tag="none",
            low_level_primitives=low_level_primitives,
            num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
            primitive_model=primitives,
            use_separate_primitives=True,
        )
    elif clone_primitives:
        # Assumed: instantiate the world model before loading pretrained weights
        # (the extracted source referenced `world_model` without constructing it here).
        world_model = WorldModel(**world_model_kwargs).to(ptu.device)
        world_model.load_state_dict(torch.load(variant["world_model_path"]))
        criterion = nn.MSELoss()
        primitive_model = Mlp(
            hidden_sizes=variant["mlp_hidden_sizes"],
            output_size=low_level_action_dim,
            input_size=world_model.feature_size + env.action_space.low.shape[0] + 1,
            hidden_activation=torch.nn.functional.relu,
        ).to(ptu.device)
        optimizer = optim.Adam(
            primitive_model.parameters(),
            **optimizer_kwargs,
        )
        # Hoisted out of the epoch loop so the checkpoint tracks the all-time best.
        best_test_loss = np.inf
        for i in tqdm(range(num_epochs)):
            if i % variant["plotting_period"] == 0:
                visualize_rollout(
                    env,
                    None,
                    None,
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=True,
                    forcing="none",
                    tag="none",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
                    primitive_model=primitive_model,
                )
                visualize_rollout(
                    env,
                    train_dataset.outputs,
                    train_dataset.inputs[1],
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=False,
                    forcing="teacher",
                    tag="train",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=num_low_level_actions_per_primitive - 1,
                )
                visualize_rollout(
                    env,
                    test_dataset.outputs,
                    test_dataset.inputs[1],
                    world_model,
                    logdir,
                    max_path_length,
                    use_env=False,
                    forcing="teacher",
                    tag="test",
                    low_level_primitives=low_level_primitives,
                    num_low_level_actions_per_primitive=num_low_level_actions_per_primitive - 1,
                )
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            total_loss = 0
            total_train_steps = 0
            for data in train_dataloader:
                with torch.cuda.amp.autocast():
                    (high_level_actions, obs), actions = data
                    obs = obs.to(ptu.device).float()
                    actions = actions.to(ptu.device).float()
                    high_level_actions = high_level_actions.to(ptu.device).float()
                    action_preds = world_model(
                        obs,
                        (high_level_actions, actions),
                        primitive_model,
                        use_network_action=False,
                    )[-1]
                    loss = criterion(action_preds, actions)
                    total_loss += loss.item()
                    total_train_steps += 1
                update_network(primitive_model, optimizer, loss, gradient_clip, scaler)
                scaler.update()
            eval_statistics["train/primitive_loss"] = total_loss / total_train_steps
            with torch.no_grad():
                total_loss = 0
                total_test_steps = 0
                for data in test_dataloader:
                    with torch.cuda.amp.autocast():
                        (high_level_actions, obs), actions = data
                        obs = obs.to(ptu.device).float()
                        actions = actions.to(ptu.device).float()
                        high_level_actions = high_level_actions.to(ptu.device).float()
                        action_preds = world_model(
                            obs,
                            (high_level_actions, actions),
                            primitive_model,
                            use_network_action=False,
                        )[-1]
                        loss = criterion(action_preds, actions)
                        total_loss += loss.item()
                        total_test_steps += 1
                eval_statistics["test/primitive_loss"] = total_loss / total_test_steps
                if (total_loss / total_test_steps) <= best_test_loss:
                    best_test_loss = total_loss / total_test_steps
                    os.makedirs(logdir + "/models/", exist_ok=True)
                    torch.save(
                        primitive_model.state_dict(),
                        logdir + "/models/primitive_model.pt",
                    )
            logger.record_dict(eval_statistics, prefix="")
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        world_model = WorldModel(**world_model_kwargs).to(ptu.device)
        optimizer = optim.Adam(
            world_model.parameters(),
            **optimizer_kwargs,
        )
        # Hoisted out of the epoch loop so the checkpoint tracks the all-time best.
        best_test_loss = np.inf
        for i in tqdm(range(num_epochs)):
            if i % variant["plotting_period"] == 0:
                visualize_wm(
                    env,
                    world_model,
                    train_dataset.inputs,
                    train_dataset.outputs,
                    test_dataset.inputs,
                    test_dataset.outputs,
                    logdir,
                    max_path_length,
                    low_level_primitives,
                    num_low_level_actions_per_primitive,
                )
            eval_statistics = OrderedDict()
            print("Epoch: ", i)
            total_wm_loss = 0
            total_div_loss = 0
            total_image_pred_loss = 0
            total_transition_loss = 0
            total_entropy_loss = 0
            total_train_steps = 0
            for data in train_dataloader:
                with torch.cuda.amp.autocast():
                    actions, obs = data
                    obs = obs.to(ptu.device).float()
                    actions = actions.to(ptu.device).float()
                    post, prior, post_dist, prior_dist, image_dist = world_model(
                        obs, actions
                    )[:5]
                    obs = world_model.flatten_obs(
                        obs.permute(1, 0, 2), (int(np.prod(image_shape)),)
                    )
                    (
                        world_model_loss,
                        div,
                        image_pred_loss,
                        transition_loss,
                        entropy_loss,
                    ) = compute_world_model_loss(
                        world_model,
                        image_shape,
                        image_dist,
                        prior,
                        post,
                        prior_dist,
                        post_dist,
                        obs,
                        **world_model_loss_kwargs,
                    )
                    total_wm_loss += world_model_loss.item()
                    total_div_loss += div.item()
                    total_image_pred_loss += image_pred_loss.item()
                    total_transition_loss += transition_loss.item()
                    total_entropy_loss += entropy_loss.item()
                    total_train_steps += 1
                update_network(
                    world_model, optimizer, world_model_loss, gradient_clip, scaler
                )
                scaler.update()
            eval_statistics["train/wm_loss"] = total_wm_loss / total_train_steps
            eval_statistics["train/div_loss"] = total_div_loss / total_train_steps
            eval_statistics["train/image_pred_loss"] = (
                total_image_pred_loss / total_train_steps
            )
            eval_statistics["train/transition_loss"] = (
                total_transition_loss / total_train_steps
            )
            eval_statistics["train/entropy_loss"] = (
                total_entropy_loss / total_train_steps
            )
            with torch.no_grad():
                total_wm_loss = 0
                total_div_loss = 0
                total_image_pred_loss = 0
                total_transition_loss = 0
                total_entropy_loss = 0
                total_test_steps = 0
                for data in test_dataloader:
                    with torch.cuda.amp.autocast():
                        actions, obs = data
                        obs = obs.to(ptu.device).float()
                        actions = actions.to(ptu.device).float()
                        post, prior, post_dist, prior_dist, image_dist = world_model(
                            obs, actions
                        )[:5]
                        obs = world_model.flatten_obs(
                            obs.permute(1, 0, 2), (int(np.prod(image_shape)),)
                        )
                        (
                            world_model_loss,
                            div,
                            image_pred_loss,
                            transition_loss,
                            entropy_loss,
                        ) = compute_world_model_loss(
                            world_model,
                            image_shape,
                            image_dist,
                            prior,
                            post,
                            prior_dist,
                            post_dist,
                            obs,
                            **world_model_loss_kwargs,
                        )
                        total_wm_loss += world_model_loss.item()
                        total_div_loss += div.item()
                        total_image_pred_loss += image_pred_loss.item()
                        total_transition_loss += transition_loss.item()
                        total_entropy_loss += entropy_loss.item()
                        total_test_steps += 1
                eval_statistics["test/wm_loss"] = total_wm_loss / total_test_steps
                eval_statistics["test/div_loss"] = total_div_loss / total_test_steps
                eval_statistics["test/image_pred_loss"] = (
                    total_image_pred_loss / total_test_steps
                )
                eval_statistics["test/transition_loss"] = (
                    total_transition_loss / total_test_steps
                )
                eval_statistics["test/entropy_loss"] = (
                    total_entropy_loss / total_test_steps
                )
                if (total_wm_loss / total_test_steps) <= best_test_loss:
                    best_test_loss = total_wm_loss / total_test_steps
                    os.makedirs(logdir + "/models/", exist_ok=True)
                    torch.save(
                        world_model.state_dict(),
                        logdir + "/models/world_model.pt",
                    )
            logger.record_dict(eval_statistics, prefix="")
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        world_model.load_state_dict(torch.load(logdir + "/models/world_model.pt"))
        visualize_wm(
            env,
            world_model,
            train_dataset,
            test_dataset,
            logdir,
            max_path_length,
            low_level_primitives,
            num_low_level_actions_per_primitive,
        )
def test_trainer_save_load():
    env_kwargs = dict(
        use_image_obs=True,
        imwidth=64,
        imheight=64,
        reward_type="sparse",
        usage_kwargs=dict(
            max_path_length=5,
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
        ),
        action_space_kwargs=dict(
            control_mode="primitives",
            action_scale=1,
            camera_settings={
                "distance": 0.38227044687537043,
                "lookat": [0.21052547, 0.32329237, 0.587819],
                "azimuth": 141.328125,
                "elevation": -53.203125160653144,
            },
        ),
    )
    actor_kwargs = dict(
        discrete_continuous_dist=True,
        init_std=0.0,
        num_layers=4,
        min_std=0.1,
        dist="tanh_normal_dreamer_v1",
    )
    vf_kwargs = dict(num_layers=3)
    model_kwargs = dict(
        model_hidden_size=400,
        stochastic_state_size=50,
        deterministic_state_size=200,
        rssm_hidden_size=200,
        reward_num_layers=2,
        pred_discount_num_layers=3,
        gru_layer_norm=True,
        std_act="sigmoid2",
        use_prior_instead_of_posterior=False,
    )
    trainer_kwargs = dict(
        adam_eps=1e-5,
        discount=0.8,
        lam=0.95,
        forward_kl=False,
        free_nats=1.0,
        pred_discount_loss_scale=10.0,
        kl_loss_scale=0.0,
        transition_loss_scale=0.8,
        actor_lr=8e-5,
        vf_lr=8e-5,
        world_model_lr=3e-4,
        reward_loss_scale=2.0,
        use_pred_discount=True,
        policy_gradient_loss_scale=1.0,
        actor_entropy_loss_schedule="1e-4",
        target_update_period=100,
        detach_rewards=False,
        imagination_horizon=5,
    )
    env_suite = "metaworld"
    env_name = "disassemble-v2"
    eval_envs = [make_env(env_suite, env_name, env_kwargs) for _ in range(1)]
    eval_env = DummyVecEnv(eval_envs)
    discrete_continuous_dist = True
    continuous_action_dim = eval_envs[0].max_arg_len
    discrete_action_dim = eval_envs[0].num_primitives
    if not discrete_continuous_dist:
        continuous_action_dim = continuous_action_dim + discrete_action_dim
        discrete_action_dim = 0
    action_dim = continuous_action_dim + discrete_action_dim
    world_model = WorldModel(
        action_dim,
        image_shape=eval_envs[0].image_shape,
        **model_kwargs,
    )
    actor = ActorModel(
        model_kwargs["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **actor_kwargs,
    )
    vf = Mlp(
        hidden_sizes=[model_kwargs["model_hidden_size"]] * vf_kwargs["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )
    target_vf = Mlp(
        hidden_sizes=[model_kwargs["model_hidden_size"]] * vf_kwargs["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )
    trainer = DreamerV2Trainer(
        actor,
        vf,
        target_vf,
        world_model,
        eval_envs[0].image_shape,
        **trainer_kwargs,
    )
    with tempfile.TemporaryDirectory() as tmpdirname:
        trainer.save(tmpdirname, "trainer.pkl")
        trainer = DreamerV2Trainer(
            actor,
            vf,
            target_vf,
            world_model,
            eval_envs[0].image_shape,
            **trainer_kwargs,
        )
        new_trainer = trainer.load(tmpdirname, "trainer.pkl")
def experiment(variant):
    import os
    import os.path as osp

    os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"
    import torch
    import torch.nn as nn

    import rlkit.torch.pytorch_util as ptu
    from rlkit.core import logger
    from rlkit.envs.primitives_make_env import make_env
    from rlkit.envs.wrappers.mujoco_vec_wrappers import (
        DummyVecEnv,
        StableBaselinesVecEnv,
    )
    from rlkit.torch.model_based.dreamer.actor_models import ActorModel
    from rlkit.torch.model_based.dreamer.dreamer_policy import (
        ActionSpaceSamplePolicy,
        DreamerLowLevelRAPSPolicy,
    )
    from rlkit.torch.model_based.dreamer.dreamer_v2 import DreamerV2LowLevelRAPSTrainer
    from rlkit.torch.model_based.dreamer.episode_replay_buffer import (
        EpisodeReplayBufferLowLevelRAPS,
    )
    from rlkit.torch.model_based.dreamer.mlp import Mlp
    from rlkit.torch.model_based.dreamer.path_collector import VecMdpPathCollector
    from rlkit.torch.model_based.dreamer.rollout_functions import (
        vec_rollout_low_level_raps,
    )
    from rlkit.torch.model_based.dreamer.visualization import (
        post_epoch_visualize_func,
        visualize_primitive_unsubsampled_rollout,
    )
    from rlkit.torch.model_based.dreamer.world_models import LowlevelRAPSWorldModel
    from rlkit.torch.model_based.rl_algorithm import TorchBatchRLAlgorithm

    env_suite = variant.get("env_suite", "kitchen")
    env_kwargs = variant["env_kwargs"]
    num_expl_envs = variant["num_expl_envs"]
    num_low_level_actions_per_primitive = variant[
        "num_low_level_actions_per_primitive"
    ]
    low_level_action_dim = variant["low_level_action_dim"]
    print("MAKING ENVS")
    env_name = variant["env_name"]
    make_env_lambda = lambda: make_env(env_suite, env_name, env_kwargs)

    if num_expl_envs > 1:
        env_fns = [make_env_lambda for _ in range(num_expl_envs)]
        expl_env = StableBaselinesVecEnv(
            env_fns=env_fns,
            start_method="fork",
            reload_state_args=(
                num_expl_envs,
                make_env,
                (env_suite, env_name, env_kwargs),
            ),
        )
    else:
        expl_envs = [make_env_lambda()]
        expl_env = DummyVecEnv(
            expl_envs, pass_render_kwargs=variant.get("pass_render_kwargs", False)
        )
    eval_envs = [make_env_lambda() for _ in range(1)]
    eval_env = DummyVecEnv(
        eval_envs, pass_render_kwargs=variant.get("pass_render_kwargs", False)
    )

    discrete_continuous_dist = variant["actor_kwargs"]["discrete_continuous_dist"]
    num_primitives = eval_envs[0].num_primitives
    continuous_action_dim = eval_envs[0].max_arg_len
    discrete_action_dim = num_primitives
    if not discrete_continuous_dist:
        continuous_action_dim = continuous_action_dim + discrete_action_dim
        discrete_action_dim = 0
    action_dim = continuous_action_dim + discrete_action_dim
    obs_dim = expl_env.observation_space.low.size

    primitive_model = Mlp(
        output_size=variant["low_level_action_dim"],
        input_size=variant["model_kwargs"]["stochastic_state_size"]
        + variant["model_kwargs"]["deterministic_state_size"]
        + eval_env.envs[0].action_space.low.shape[0]
        + 1,
        hidden_activation=nn.ReLU,
        num_embeddings=eval_envs[0].num_primitives,
        embedding_dim=eval_envs[0].num_primitives,
        embedding_slice=eval_envs[0].num_primitives,
        **variant["primitive_model_kwargs"],
    )
    world_model = LowlevelRAPSWorldModel(
        low_level_action_dim,
        image_shape=eval_envs[0].image_shape,
        primitive_model=primitive_model,
        **variant["model_kwargs"],
    )
    actor = ActorModel(
        variant["model_kwargs"]["model_hidden_size"],
        world_model.feature_size,
        hidden_activation=nn.ELU,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        **variant["actor_kwargs"],
    )
    vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]]
        * variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )
    target_vf = Mlp(
        hidden_sizes=[variant["model_kwargs"]["model_hidden_size"]]
        * variant["vf_kwargs"]["num_layers"],
        output_size=1,
        input_size=world_model.feature_size,
        hidden_activation=nn.ELU,
    )

    if variant.get("models_path", None) is not None:
        filename = variant["models_path"]
        actor.load_state_dict(torch.load(osp.join(filename, "actor.ptc")))
        vf.load_state_dict(torch.load(osp.join(filename, "vf.ptc")))
        target_vf.load_state_dict(torch.load(osp.join(filename, "target_vf.ptc")))
        world_model.load_state_dict(torch.load(osp.join(filename, "world_model.ptc")))
        print("LOADED MODELS")

    expl_policy = DreamerLowLevelRAPSPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
        low_level_action_dim=low_level_action_dim,
        exploration=True,
        expl_amount=variant.get("expl_amount", 0.3),
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )
    eval_policy = DreamerLowLevelRAPSPolicy(
        world_model,
        actor,
        obs_dim,
        action_dim,
        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
        low_level_action_dim=low_level_action_dim,
        exploration=False,
        expl_amount=0.0,
        discrete_action_dim=discrete_action_dim,
        continuous_action_dim=continuous_action_dim,
        discrete_continuous_dist=discrete_continuous_dist,
    )
    initial_data_collection_policy = ActionSpaceSamplePolicy(expl_env)

    rollout_function_kwargs = dict(
        num_low_level_actions_per_primitive=num_low_level_actions_per_primitive,
        low_level_action_dim=low_level_action_dim,
        num_primitives=num_primitives,
    )
    expl_path_collector = VecMdpPathCollector(
        expl_env,
        expl_policy,
        save_env_in_snapshot=False,
        rollout_fn=vec_rollout_low_level_raps,
        rollout_function_kwargs=rollout_function_kwargs,
    )
    eval_path_collector = VecMdpPathCollector(
        eval_env,
        eval_policy,
        save_env_in_snapshot=False,
        rollout_fn=vec_rollout_low_level_raps,
        rollout_function_kwargs=rollout_function_kwargs,
    )

    replay_buffer = EpisodeReplayBufferLowLevelRAPS(
        num_expl_envs, obs_dim, action_dim, **variant["replay_buffer_kwargs"]
    )
    filename = variant.get("replay_buffer_path", None)
    if filename is not None:
        replay_buffer.load_buffer(filename, eval_env.envs[0].num_primitives)
    eval_filename = variant.get("eval_buffer_path", None)
    if eval_filename is not None:
        eval_buffer = EpisodeReplayBufferLowLevelRAPS(
            1000,
            expl_env,
            variant["algorithm_kwargs"]["max_path_length"],
            num_low_level_actions_per_primitive,
            obs_dim,
            action_dim,
            low_level_action_dim,
            replace=False,
        )
        eval_buffer.load_buffer(eval_filename, eval_env.envs[0].num_primitives)
    else:
        eval_buffer = None

    trainer = DreamerV2LowLevelRAPSTrainer(
        actor,
        vf,
        target_vf,
        world_model,
        eval_envs[0].image_shape,
        **variant["trainer_kwargs"],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        pretrain_policy=initial_data_collection_policy,
        **variant["algorithm_kwargs"],
        eval_buffer=eval_buffer,
    )
    algorithm.low_level_primitives = True
    if variant.get("generate_video", False):
        post_epoch_visualize_func(algorithm, 0)
    elif variant.get("unsubsampled_rollout", False):
        visualize_primitive_unsubsampled_rollout(
            make_env_lambda(),
            make_env_lambda(),
            make_env_lambda(),
            logger.get_snapshot_dir(),
            algorithm.max_path_length,
            num_low_level_actions_per_primitive,
            policy=eval_policy,
            img_size=64,
            num_rollouts=4,
        )
    else:
        if variant.get("save_video", False):
            algorithm.post_epoch_funcs.append(post_epoch_visualize_func)
        print("TRAINING")
        algorithm.to(ptu.device)
        algorithm.train()
        if variant.get("save_video", False):
            post_epoch_visualize_func(algorithm, -1)
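# A hedged sketch of the variant dict the experiment above consumes. The key
# names are exactly those read by `experiment`; the values are illustrative
# assumptions for a small smoke run, not settings from the launch configs.
# `primitive_model_kwargs`, `replay_buffer_kwargs`, and `algorithm_kwargs`
# in particular are placeholders whose real entries live elsewhere.
example_low_level_raps_variant = dict(
    env_suite="kitchen",
    env_name="hinge_cabinet",
    env_kwargs=dict(
        reward_type="sparse",
        use_image_obs=True,
        action_scale=1.4,
        use_workspace_limits=True,
        control_mode="primitives",
        usage_kwargs=dict(
            use_dm_backend=True,
            use_raw_action_wrappers=False,
            unflatten_images=False,
            max_path_length=5,
        ),
        action_space_kwargs=dict(),
    ),
    num_expl_envs=1,  # 1 uses DummyVecEnv; >1 forks StableBaselinesVecEnv workers
    num_low_level_actions_per_primitive=10,  # assumed value
    low_level_action_dim=9,  # assumed low-level (joint/end-effector) action size
    actor_kwargs=dict(
        discrete_continuous_dist=True,
        init_std=0.0,
        num_layers=4,
        min_std=0.1,
        dist="tanh_normal_dreamer_v1",
    ),
    vf_kwargs=dict(num_layers=3),
    model_kwargs=dict(
        model_hidden_size=400,
        stochastic_state_size=50,
        deterministic_state_size=200,
        rssm_hidden_size=200,
        reward_num_layers=2,
        pred_discount_num_layers=3,
        gru_layer_norm=True,
        std_act="sigmoid2",
        use_prior_instead_of_posterior=False,
    ),
    primitive_model_kwargs=dict(hidden_sizes=[512, 512]),  # assumed
    trainer_kwargs=dict(  # abbreviated; see test_trainer_save_load for a fuller set
        adam_eps=1e-5,
        discount=0.8,
        lam=0.95,
        imagination_horizon=5,
    ),
    replay_buffer_kwargs=dict(),  # fill from the launcher config
    algorithm_kwargs=dict(  # assumed rlkit-style batch-RL keys
        batch_size=50,
        num_epochs=5,
        max_path_length=5,
    ),
)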
    },
    usage_kwargs=dict(
        use_dm_backend=True,
        use_raw_action_wrappers=False,
        use_image_obs=True,
        max_path_length=5,
        unflatten_images=False,
    ),
    image_kwargs=dict(imwidth=64, imheight=64),
    collect_primitives_info=True,
    include_phase_variable=True,
)
env_suite = "metaworld"
env_name = "reach-v2"
env = make_env(env_suite, env_name, env_kwargs)
file_path = osp.join("data/" + args.logdir + "/test.avi")
a1 = env.action_space.sample()
a1[0] = 100
obs = env.reset()
o, r, d, i = env.step(
    a1,
    render_every_step=True,
    render_mode="rgb_array",
    render_im_shape=(480, 480),
)
true_actions1 = np.array(i["actions"])
true_states1 = np.array(i["robot-states"])
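# A sketch (not part of the original script) of one use for the collected
# info: with collect_primitives_info=True, each env.step records the
# low-level actions and robot states the primitive executed, so repeating the
# same primitive action from a fresh reset should reproduce them, assuming
# the reset is deterministic. The helper name below is hypothetical.
def check_primitive_info_reproducible(env, action):
    env.reset()
    _, _, _, info_a = env.step(action)
    env.reset()
    _, _, _, info_b = env.step(action)
    # compare the recorded low-level trajectories from the two rollouts
    actions_match = np.allclose(
        np.array(info_a["actions"]), np.array(info_b["actions"])
    )
    states_match = np.allclose(
        np.array(info_a["robot-states"]), np.array(info_b["robot-states"])
    )
    return actions_match and states_match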
def experiment(variant):
    gym.logger.set_level(40)
    work_dir = rlkit_logger.get_snapshot_dir()
    args = parse_args()
    seed = int(variant["seed"])
    utils.set_seed_everywhere(seed)
    os.makedirs(work_dir, exist_ok=True)

    agent_kwargs = variant["agent_kwargs"]
    data_augs = agent_kwargs["data_augs"]
    encoder_type = agent_kwargs["encoder_type"]
    discrete_continuous_dist = agent_kwargs["discrete_continuous_dist"]
    env_suite = variant["env_suite"]
    env_name = variant["env_name"]
    env_kwargs = variant["env_kwargs"]
    pre_transform_image_size = variant["pre_transform_image_size"]
    image_size = variant["image_size"]
    frame_stack = variant["frame_stack"]
    batch_size = variant["batch_size"]
    replay_buffer_capacity = variant["replay_buffer_capacity"]
    num_train_steps = variant["num_train_steps"]
    num_eval_episodes = variant["num_eval_episodes"]
    eval_freq = variant["eval_freq"]
    action_repeat = variant["action_repeat"]
    init_steps = variant["init_steps"]
    log_interval = variant["log_interval"]
    use_raw_actions = variant["use_raw_actions"]

    # crop/translate augmentations render at a larger pre-transform size
    pre_transform_image_size = (
        pre_transform_image_size if "crop" in data_augs else image_size
    )
    if data_augs == "crop":
        pre_transform_image_size = 100
    elif data_augs == "translate":
        pre_transform_image_size = 100
        image_size = 108

    if env_suite == "kitchen":
        env_kwargs["imwidth"] = pre_transform_image_size
        env_kwargs["imheight"] = pre_transform_image_size
    else:
        env_kwargs["image_kwargs"]["imwidth"] = pre_transform_image_size
        env_kwargs["image_kwargs"]["imheight"] = pre_transform_image_size

    expl_env = primitives_make_env.make_env(env_suite, env_name, env_kwargs)
    eval_env = primitives_make_env.make_env(env_suite, env_name, env_kwargs)

    # stack several consecutive frames together
    if encoder_type == "pixel":
        expl_env = utils.FrameStack(expl_env, k=frame_stack)
        eval_env = utils.FrameStack(eval_env, k=frame_stack)

    # make directories for videos, model snapshots, and buffer dumps
    ts = time.strftime("%m-%d", time.gmtime())
    exp_name = (
        env_name
        + "-"
        + ts
        + "-im"
        + str(image_size)
        + "-b"
        + str(batch_size)
        + "-s"
        + str(seed)
        + "-"
        + encoder_type
    )
    work_dir = work_dir + "/" + exp_name
    utils.make_dir(work_dir)
    video_dir = utils.make_dir(os.path.join(work_dir, "video"))
    model_dir = utils.make_dir(os.path.join(work_dir, "model"))
    buffer_dir = utils.make_dir(os.path.join(work_dir, "buffer"))
    video = VideoRecorder(video_dir if args.save_video else None)
    with open(os.path.join(work_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if use_raw_actions:
        continuous_action_dim = expl_env.action_space.low.size
        discrete_action_dim = 0
    else:
        num_primitives = expl_env.num_primitives
        max_arg_len = expl_env.max_arg_len
        if discrete_continuous_dist:
            continuous_action_dim = max_arg_len
            discrete_action_dim = num_primitives
        else:
            continuous_action_dim = max_arg_len + num_primitives
            discrete_action_dim = 0

    if encoder_type == "pixel":
        obs_shape = (3 * frame_stack, image_size, image_size)
        pre_aug_obs_shape = (
            3 * frame_stack,
            pre_transform_image_size,
            pre_transform_image_size,
        )
    else:
        obs_shape = expl_env.observation_space.shape
        pre_aug_obs_shape = obs_shape

    replay_buffer = utils.ReplayBuffer(
        obs_shape=pre_aug_obs_shape,
        action_size=continuous_action_dim + discrete_action_dim,
        capacity=replay_buffer_capacity,
        batch_size=batch_size,
        device=device,
        image_size=image_size,
        pre_image_size=pre_transform_image_size,
    )
    agent = make_agent(
        obs_shape=obs_shape,
        continuous_action_dim=continuous_action_dim,
        discrete_action_dim=discrete_action_dim,
        args=args,
        device=device,
        agent_kwargs=agent_kwargs,
    )
    L = Logger(work_dir, use_tb=args.save_tb)

    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    epoch_start_time = time.time()
    train_expl_st = time.time()
    total_train_expl_time = 0
    all_infos = []
    ep_infos = []
    num_train_calls = 0
    for step in range(num_train_steps):
        # evaluate agent periodically
        if step % eval_freq == 0:
            total_train_expl_time += time.time() - train_expl_st
            L.log("eval/episode", episode, step)
            evaluate(
                eval_env,
                agent,
                video,
                num_eval_episodes,
                L,
                step,
                encoder_type,
                data_augs,
                image_size,
                pre_transform_image_size,
                env_name,
                action_repeat,
                work_dir,
                seed,
            )
            if args.save_model:
                agent.save_curl(model_dir, step)
            if args.save_buffer:
                replay_buffer.save(buffer_dir)
            train_expl_st = time.time()
        if done:
            if step > 0:
                if step % log_interval == 0:
                    L.log("train/duration", time.time() - epoch_start_time, step)
                    L.dump(step)
            if step % log_interval == 0:
                L.log("train/episode_reward", episode_reward, step)
            obs = expl_env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1
            if step % log_interval == 0:
                all_infos.append(ep_infos)
                L.log("train/episode", episode, step)
                statistics = compute_path_info(all_infos)
                rlkit_logger.record_dict(statistics, prefix="exploration/")
                rlkit_logger.record_tabular(
                    "time/epoch (s)", time.time() - epoch_start_time
                )
                rlkit_logger.record_tabular("time/total (s)", time.time() - start_time)
                rlkit_logger.record_tabular(
                    "time/training and exploration (s)", total_train_expl_time
                )
                rlkit_logger.record_tabular("trainer/num train calls", num_train_calls)
                rlkit_logger.record_tabular("exploration/num steps total", step)
                rlkit_logger.record_tabular("Epoch", step // log_interval)
                rlkit_logger.dump_tabular(with_prefix=False, with_timestamp=False)
                all_infos = []
                epoch_start_time = time.time()
            ep_infos = []
        # sample action for data collection
        if step < init_steps:
            action = expl_env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs / 255.0)
        # run training update
        if step >= init_steps:
            num_updates = 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)
                num_train_calls += 1
        next_obs, reward, done, info = expl_env.step(action)
        ep_infos.append(info)
        # allow infinite bootstrap: time-limit cutoffs are stored as not done
        done_bool = (
            0 if episode_step + 1 == expl_env._max_episode_steps else float(done)
        )
        episode_reward += reward
        replay_buffer.add(obs, action, reward, next_obs, done_bool)
        obs = next_obs
        episode_step += 1
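# The done_bool handling in the loop above implements the standard "infinite
# bootstrap" convention for time-limited environments. A sketch of it in
# isolation (the helper name is ours, not from the codebase):
def time_limit_done(done, episode_step, max_episode_steps):
    # An episode cut off purely by the time limit is stored as not-done, so
    # the critic still bootstraps from the next state; genuine terminations
    # keep their done flag.
    if episode_step + 1 == max_episode_steps:
        return 0.0
    return float(done)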