def make_video(args):
    """Load a pickled snapshot and dump a grid video of policy rollouts.

    Reads ``args.pause``, ``args.file``, ``args.gpu`` and ``args.mode``.
    The snapshot must be a mapping holding the policy under ``'policy'`` or
    ``'evaluation/policy'`` and the environment under ``'env'`` or
    ``'evaluation/env'``.

    Raises:
        AttributeError: if the snapshot contains no policy or no env entry.
    """
    if args.pause:
        import ipdb
        ipdb.set_trace()

    # NOTE: unpickling executes arbitrary code; only load snapshots you trust.
    # The context manager closes the file handle even if unpickling fails.
    with open(args.file, "rb") as snapshot_file:
        data = pickle.load(snapshot_file)  # joblib.load(args.file)

    if 'policy' in data:
        policy = data['policy']
    elif 'evaluation/policy' in data:
        policy = data['evaluation/policy']
    else:
        raise AttributeError(
            "No policy found in loaded data (expected 'policy' or "
            "'evaluation/policy'). Keys: {}".format(list(data.keys()))
        )

    if 'env' in data:
        env = data['env']
    elif 'evaluation/env' in data:
        env = data['evaluation/env']
    else:
        raise AttributeError(
            "No env found in loaded data (expected 'env' or "
            "'evaluation/env'). Keys: {}".format(list(data.keys()))
        )

    # Unwrap the remote wrapper so the underlying env is used directly.
    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")

    # Select the device first, then move the policy onto whatever
    # ptu.device resolves to for that mode.
    ptu.set_gpu_mode(bool(args.gpu))
    policy.to(ptu.device)

    if isinstance(env, VAEWrappedEnv):
        env.mode(args.mode)

    max_path_length = 100
    observation_key = 'latent_observation'
    desired_goal_key = 'latent_desired_goal'
    rollout_function = rf.create_rollout_function(
        rf.multitask_rollout,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    env.mode(env._mode_map['video_env'])

    # Short random suffix so repeated runs do not overwrite each other.
    random_id = str(uuid.uuid4()).split('-')[0]
    dump_video(
        env,
        policy,
        'rollouts_{}.mp4'.format(random_id),
        rollout_function,
        rows=3,
        columns=6,
        pad_length=0,
        pad_color=255,
        do_timer=True,
        horizon=max_path_length,
        dirname_to_save_images=None,
        subdirname="rollouts",
        imsize=48,
    )
def save_video(algo, epoch):
    """Dump a tagged rollout video for this epoch when one is due.

    A video is written when ``epoch`` is a multiple of ``save_video_period``
    or when ``epoch`` equals ``algo.num_epochs``; otherwise nothing happens.
    """
    is_periodic = epoch % save_video_period == 0
    if not (is_periodic or epoch == algo.num_epochs):
        return
    video_name = 'video_{}_{epoch}_env.mp4'.format(tag, epoch=epoch)
    filename = osp.join(logdir, video_name)
    dump_video(
        env,
        policy,
        filename,
        rollout_function,
        imsize=imsize,
        **dump_video_kwargs
    )
def save_video(algo, epoch):
    """Dump a rollout video for this epoch when one is due.

    A video is written when ``epoch`` is a multiple of ``save_video_period``
    or when ``epoch`` is at least ``algo.num_epochs - 1``. The file name
    includes ``tag`` only when it is a non-empty value.
    """
    due = epoch % save_video_period == 0 or epoch >= algo.num_epochs - 1
    if not due:
        return

    has_tag = tag is not None and len(tag) > 0
    if has_tag:
        video_name = 'video_{}_{epoch}_env.mp4'.format(tag, epoch=epoch)
    else:
        video_name = 'video_{epoch}_env.mp4'.format(epoch=epoch)

    filename = osp.join(logdir, video_name)
    dump_video(
        env,
        policy,
        filename,
        rollout_function,
        imsize=imsize,
        **dump_video_kwargs
    )
def save_video(algo, epoch):
    """Dump a rollout video of ``image_env`` for this epoch when one is due.

    A video is written when ``epoch`` is a multiple of ``save_period`` or
    when ``epoch`` equals ``algo.num_epochs``; otherwise nothing happens.
    """
    is_periodic = epoch % save_period == 0
    if not (is_periodic or epoch == algo.num_epochs):
        return
    video_name = 'video_{epoch}_env.mp4'.format(epoch=epoch)
    filename = osp.join(logdir, video_name)
    dump_video(
        image_env,
        policy,
        filename,
        rollout_function,
        **dump_video_kwargs
    )
def grill_her_sac_experiment(variant):
    """Build, train and optionally record a HerSac agent on a VAE-wrapped env.

    ``variant`` is a mapping of experiment settings. Keys read with ``[...]``
    (required): ``env_class``, ``env_kwargs``, ``render``, ``rdim``,
    ``vae_paths``, ``normalize``, ``replay_kwargs``, ``algo_kwargs``.
    Keys read with ``.get`` (optional): ``reward_params``, ``init_camera``,
    ``vae_wrapped_env_kwargs``, ``observation_key``, ``desired_goal_key``,
    ``hidden_sizes``, ``training_mode``, ``testing_mode``, ``save_video``.

    Side effect: ``variant["algo_kwargs"]["replay_buffer"]`` is overwritten
    with the replay buffer created here.
    """
    env = variant["env_class"](**variant['env_kwargs'])
    render = variant["render"]
    rdim = variant["rdim"]
    # vae_paths is indexed by the string form of rdim.
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    # Fall back to the named "topview" camera only when no init_camera is given.
    if init_camera is None:
        camera_name = "topview"
    else:
        camera_name = None
    env = ImageEnv(
        env,
        84,  # presumably the image size in pixels -- TODO confirm against ImageEnv
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )
    env = VAEWrappedEnv(env, vae_path, decode_goals=render,
                        render_goals=render, render_rollouts=render,
                        reward_params=reward_params,
                        **variant.get('vae_wrapped_env_kwargs', {}))
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    # The achieved-goal key is derived from the desired-goal key by name.
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    # Network input size is the observation size plus the goal size.
    obs_dim = (env.observation_space.spaces[observation_key].low.size
               + env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    vf = ConcatMlp(
        input_size=obs_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    # Each role gets its own copy of env (pickle round trip) so that setting
    # a mode on one copy does not affect the others.
    testing_env = pickle.loads(pickle.dumps(env))
    testing_env.mode(testing_mode)

    training_env = pickle.loads(pickle.dumps(env))
    training_env.mode(training_mode)

    relabeling_env = pickle.loads(pickle.dumps(env))
    relabeling_env.mode(training_mode)
    relabeling_env.disable_render()

    video_vae_env = pickle.loads(pickle.dumps(env))
    video_vae_env.mode("video_vae")
    video_goal_env = pickle.loads(pickle.dumps(env))
    video_goal_env.mode("video_env")

    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs'])
    # Mutates the caller's variant: the buffer is injected into algo_kwargs
    # so it reaches HerSac through the ** expansion below.
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerSac(testing_env,
                       training_env=training_env,
                       qf=qf,
                       vf=vf,
                       policy=policy,
                       render=render,
                       render_during_eval=render,
                       observation_key=observation_key,
                       desired_goal_key=desired_goal_key,
                       **variant['algo_kwargs'])

    if ptu.gpu_enabled():
        print("using GPU")
        qf.to(ptu.device)
        vf.to(ptu.device)
        policy.to(ptu.device)
        algorithm.to(ptu.device)
        # NOTE(review): relabeling_env is not in this list, so its vae is not
        # moved here -- confirm that is intentional.
        for e in [testing_env, training_env, video_vae_env, video_goal_env]:
            e.vae.to(ptu.device)

    algorithm.train()

    # Videos are recorded unless variant explicitly sets save_video to a
    # falsy value (the default is True).
    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        filename = osp.join(logdir, 'video_final_env.mp4')
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        dump_video(video_goal_env, policy, filename, rollout_function)
        filename = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, filename, rollout_function)
def grill_her_td3_experiment(variant):
    """Build, train and optionally record a HerTd3 agent on a VAE-wrapped env.

    ``variant`` is a mapping of experiment settings. Keys read with ``[...]``
    (required): ``env_class``, ``env_kwargs``, ``render``, ``rdim``,
    ``vae_paths``, ``normalize``, ``exploration_type``, ``replay_kwargs``,
    ``algo_kwargs``. Keys read with ``.get`` (optional): ``reward_params``,
    ``init_camera``, ``vae_wrapped_env_kwargs``, ``exploration_noise``,
    ``observation_key``, ``desired_goal_key``, ``hidden_sizes``,
    ``training_mode``, ``testing_mode``, ``save_video``.

    Side effect: ``variant["algo_kwargs"]["replay_buffer"]`` is overwritten
    with the replay buffer created here.

    Raises:
        Exception: if ``variant['exploration_type']`` is not one of
            ``'ou'``, ``'gaussian'`` or ``'epsilon'``.
    """
    env = variant["env_class"](**variant['env_kwargs'])
    render = variant["render"]
    rdim = variant["rdim"]
    # vae_paths is indexed by the string form of rdim.
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    # Fall back to the named "topview" camera only when no init_camera is given.
    if init_camera is None:
        camera_name = "topview"
    else:
        camera_name = None
    env = ImageEnv(
        env,
        84,  # presumably the image size in pixels -- TODO confirm against ImageEnv
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )
    env = VAEWrappedEnv(env, vae_path, decode_goals=render,
                        render_goals=render, render_rollouts=render,
                        reward_params=reward_params,
                        **variant.get('vae_wrapped_env_kwargs', {}))
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    exploration_type = variant['exploration_type']
    exploration_noise = variant.get('exploration_noise', 0.1)
    # Select the exploration strategy; exploration_noise is used as the sigma
    # for 'gaussian' and as the random-action probability for 'epsilon'.
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=exploration_noise,
            min_sigma=exploration_noise,  # Constant sigma
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=exploration_noise,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    # The achieved-goal key is derived from the desired-goal key by name.
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    # Network input size is the observation size plus the goal size.
    obs_dim = (env.observation_space.spaces[observation_key].low.size
               + env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    # Each role gets its own copy of env (pickle round trip) so that setting
    # a mode on one copy does not affect the others.
    testing_env = pickle.loads(pickle.dumps(env))
    testing_env.mode(testing_mode)

    training_env = pickle.loads(pickle.dumps(env))
    training_env.mode(training_mode)

    relabeling_env = pickle.loads(pickle.dumps(env))
    relabeling_env.mode(training_mode)
    relabeling_env.disable_render()

    video_vae_env = pickle.loads(pickle.dumps(env))
    video_vae_env.mode("video_vae")
    video_goal_env = pickle.loads(pickle.dumps(env))
    video_goal_env.mode("video_env")

    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs'])
    # Mutates the caller's variant: the buffer is injected into algo_kwargs
    # so it reaches HerTd3 through the ** expansion below.
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerTd3(testing_env,
                       training_env=training_env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       render=render,
                       render_during_eval=render,
                       observation_key=observation_key,
                       desired_goal_key=desired_goal_key,
                       **variant['algo_kwargs'])

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.to(ptu.device)
        # NOTE(review): relabeling_env is not in this list, so its vae is not
        # moved here -- confirm that is intentional.
        for e in [testing_env, training_env, video_vae_env, video_goal_env]:
            e.vae.to(ptu.device)

    algorithm.train()

    # Videos are recorded unless variant explicitly sets save_video to a
    # falsy value (the default is True).
    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        filename = osp.join(logdir, 'video_final_env.mp4')
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        dump_video(video_goal_env, policy, filename, rollout_function)
        filename = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, filename, rollout_function)
def simulate_policy(args):
    """Load a snapshot with joblib and dump a grid video of TDM rollouts.

    Reads ``args.file``, ``args.gpu``, ``args.pause`` and ``args.H``. The
    policy is taken from the first of ``'eval_policy'``, ``'policy'`` or
    ``'exploration_policy'`` present in the snapshot; the env comes from
    ``'env'``. The video is written next to the snapshot file.

    Raises:
        Exception: if none of the policy keys is present in the snapshot.
    """
    data = joblib.load(args.file)

    # Take the first policy entry that exists, in priority order.
    for policy_key in ('eval_policy', 'policy', 'exploration_policy'):
        if policy_key in data:
            policy = data[policy_key]
            break
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))

    max_tau = get_max_tau(args)

    env = data['env']
    env.mode("video_env")
    env.decode_goals = True
    if hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()

    # Put the policy (and the env's vae, when it has one) on the requested
    # device; on the CPU path everything is explicitly moved to the CPU.
    use_gpu = bool(args.gpu)
    set_gpu_mode(use_gpu)
    if use_gpu:
        policy.to(ptu.device)
    else:
        policy.cpu()
    if hasattr(env, "vae"):
        if use_gpu:
            env.vae.to(ptu.device)
        else:
            env.vae.cpu()

    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        policy.train(False)

    grid_rows = 3
    grid_columns = 6
    snapshot_dir = osp.dirname(args.file)
    snapshot_stem = os.path.splitext(os.path.basename(args.file))[0]
    video_path = osp.join(snapshot_dir, "video_{}.mp4".format(snapshot_stem))

    rollout_function = create_rollout_function(
        tdm_rollout,
        init_tau=max_tau,
        observation_key='observation',
        desired_goal_key='desired_goal',
    )
    paths = dump_video(
        env,
        policy,
        video_path,
        rollout_function,
        ROWS=grid_rows,
        COLUMNS=grid_columns,
        horizon=args.H,
        dirname_to_save_images=snapshot_dir,
        subdirname="rollouts_" + snapshot_stem,
    )

    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics(paths)
    logger.dump_tabular()