def train_vae(variant, return_data=False):
    from rlkit.torch.vae.conv_vae import ConvVAE, imsize48_default_architecture
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.pythonplusplus import identity
    import torch
    import os
    import os.path as osp

    beta = variant["beta"]
    representation_size = variant.get("representation_size", 4)
    train_dataset, test_dataset = generate_vae_data(variant)
    decoder_activation = identity
    architecture = variant.get('vae_architecture',
                               imsize48_default_architecture)
    image_size = variant.get('image_size', 48)
    input_channels = variant.get('input_channels', 1)
    vae_model = ConvVAE(representation_size,
                        decoder_output_activation=decoder_activation,
                        architecture=architecture,
                        imsize=image_size,
                        input_channels=input_channels,
                        decoder_distribution='gaussian_identity_variance')
    vae_model.cuda()
    vae_trainer = ConvVAETrainer(train_dataset,
                                 test_dataset,
                                 vae_model,
                                 beta=beta,
                                 beta_schedule=None)
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        vae_trainer.train_epoch(epoch)
        vae_trainer.test_epoch(epoch)
        if epoch % save_period == 0:
            vae_trainer.dump_samples(epoch)
        vae_trainer.update_train_weights()
    # Save the trained weights under <project>/saved_model/vae_model.pkl.
    project_path = osp.abspath(os.curdir)
    save_dir = osp.join(project_path, 'saved_model', 'vae_model.pkl')
    torch.save(vae_model.state_dict(), save_dir)
    if return_data:
        return vae_model, train_dataset, test_dataset
    return vae_model
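# Usage sketch for the train_vae above. The keys are exactly the ones the
# function reads, but the values are illustrative assumptions, not settings
# from the original project, and `generate_vae_data` must be provided
# elsewhere in the module.
def _example_train_vae():
    example_variant = dict(
        beta=5,                  # KL weight passed to ConvVAETrainer
        representation_size=4,   # latent dimension (matches the default above)
        num_epochs=100,
        save_period=10,          # dump VAE samples every 10 epochs
        image_size=48,
        input_channels=1,
    )
    return train_vae(example_variant)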
def experiment(variant): from rlkit.core import logger import rlkit.torch.pytorch_util as ptu beta = variant["beta"] representation_size = variant["representation_size"] train_data, test_data, info = generate_vae_dataset( **variant['get_data_kwargs']) logger.save_extra_data(info) logger.get_snapshot_dir() if 'beta_schedule_kwargs' in variant: beta_schedule = PiecewiseLinearSchedule( **variant['beta_schedule_kwargs']) else: beta_schedule = None m = ConvVAE(representation_size, input_channels=3, **variant['conv_vae_kwargs']) if ptu.gpu_enabled(): m.to(ptu.device) t = ConvVAETrainer(train_data, test_data, m, beta=beta, beta_schedule=beta_schedule, **variant['algo_kwargs']) save_period = variant['save_period'] for epoch in range(variant['num_epochs']): should_save_imgs = (epoch % save_period == 0) t.train_epoch(epoch) t.test_epoch(epoch, save_reconstruction=should_save_imgs, save_scatterplot=should_save_imgs) if should_save_imgs: t.dump_samples(epoch)
def experiment(variant):
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data = get_data(10000)
    m = ConvVAE(representation_size)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta, use_cuda=False)
    for epoch in range(10):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        t.dump_samples(epoch)
def train_vae(variant, return_data=False):
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')
    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m
def experiment(variant):
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    beta_schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        t.dump_samples(epoch)
def tdm_td3_experiment_online_vae(variant):
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.state_distance.tdm_networks import TdmQf, TdmPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.torch.online_vae.online_vae_tdm_td3 import OnlineVaeTdmTd3

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs')
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = env.observation_space.spaces[observation_key].low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size
    vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
        'vectorized'] = vectorized
    norm_order = env.norm_order
    qf1 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    qf2 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    policy = TdmPolicy(env=env,
                       observation_dim=obs_dim,
                       goal_dim=goal_dim,
                       action_dim=action_dim,
                       **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    algo_kwargs = variant['algo_kwargs']['tdm_td3_kwargs']
    td3_kwargs = algo_kwargs['td3_kwargs']
    td3_kwargs['training_env'] = env
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algo_kwargs["replay_buffer"] = replay_buffer
    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)
    render = variant["render"]
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    algorithm = OnlineVaeTdmTd3(
        online_vae_kwargs=dict(
            vae=vae,
            vae_trainer=t,
            **variant['algo_kwargs']['online_vae_kwargs']),
        tdm_td3_kwargs=dict(
            env=env,
            qf1=qf1,
            qf2=qf2,
            policy=policy,
            exploration_policy=exploration_policy,
            **variant['algo_kwargs']['tdm_td3_kwargs']),
    )
    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)
    algorithm.train()
def td3_experiment_online_vae_exploring(variant):
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.torch.her.online_vae_joint_algo import OnlineVaeHerJointAlgo
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    exploring_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploring_exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=exploring_policy,
    )
    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    if variant.get('use_replay_buffer_goals', False):
        env.replay_buffer = replay_buffer
        env.use_replay_buffer_goals = True
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs')
    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)
    control_algorithm = TD3(env=env,
                            training_env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs'])
    exploring_algorithm = TD3(
        env=env,
        training_env=env,
        qf1=exploring_qf1,
        qf2=exploring_qf2,
        policy=exploring_policy,
        exploration_policy=exploring_exploration_policy,
        **variant['algo_kwargs'])
    assert 'vae_training_schedule' not in variant, \
        "Just put it in joint_algo_kwargs"
    algorithm = OnlineVaeHerJointAlgo(vae=vae,
                                      vae_trainer=t,
                                      env=env,
                                      training_env=env,
                                      policy=policy,
                                      exploration_policy=exploration_policy,
                                      replay_buffer=replay_buffer,
                                      algo1=control_algorithm,
                                      algo2=exploring_algorithm,
                                      algo1_prefix="Control_",
                                      algo2_prefix="VAE_Exploration_",
                                      observation_key=observation_key,
                                      desired_goal_key=desired_goal_key,
                                      **variant['joint_algo_kwargs'])
    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    algorithm.train()
def train_vae(variant, other_variant, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE,
    )
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')
    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       other_variant,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
        t.update_train_weights()
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    # torch.save(m, other_variant['vae_pkl_path'] + '/online_vae.pkl')
    # Easy way: load the model for the VAE bonus.
    if return_data:
        return m, train_data, test_data
    return m
def skewfit_experiment(variant, other_variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(variant)
    env = get_envs(variant)
    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )
    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    vae_trainer = ConvVAETrainer(variant['vae_train_data'],
                                 variant['vae_test_data'],
                                 env.vae,
                                 other_variant,
                                 **variant['online_vae_trainer_kwargs'])
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']
    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['twin_sac_trainer_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        MakeDeterministic(policy),
        max_path_length,
        other_variant=other_variant,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        other_variant=other_variant,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])
    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals
    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
def train_vae(beta,
              representation_size,
              imsize,
              num_epochs,
              save_period,
              generate_vae_dataset_fctn=None,
              beta_schedule_kwargs=None,
              decoder_activation=None,
              vae_kwargs=None,
              generate_vae_dataset_kwargs=None,
              algo_kwargs=None,
              use_spatial_auto_encoder=False,
              vae_class=None,
              dump_skew_debug_plots=False):
    from rlkit.misc.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.torch.vae.vae_experiment import VAEExperiment
    from rlkit.pythonplusplus import identity
    from rlkit.torch.grill.launcher import generate_vae_dataset
    import torch

    if vae_kwargs is None:
        vae_kwargs = {}
    if generate_vae_dataset_kwargs is None:
        generate_vae_dataset_kwargs = {}
    if algo_kwargs is None:
        algo_kwargs = {}
    if generate_vae_dataset_fctn is None:
        generate_vae_dataset_fctn = generate_vae_dataset
    if vae_class is None:
        vae_class = ConvVAE
    if beta_schedule_kwargs is not None:
        beta_schedule = PiecewiseLinearSchedule(**beta_schedule_kwargs)
    else:
        beta_schedule = None
    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = vae_kwargs.get('architecture', None)
    if not architecture and imsize == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and imsize == 48:
        architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize
    if algo_kwargs.get('is_auto_encoder', False):
        m = AutoEncoder(representation_size,
                        decoder_output_activation=decoder_activation,
                        **vae_kwargs)
    elif use_spatial_auto_encoder:
        m = SpatialAutoEncoder(representation_size,
                               decoder_output_activation=decoder_activation,
                               **vae_kwargs)
    else:
        m = vae_class(representation_size,
                      decoder_output_activation=decoder_activation,
                      **vae_kwargs)
    train_data, test_data, info = generate_vae_dataset_fctn(
        generate_vae_dataset_kwargs)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **algo_kwargs)
    vae_exp = VAEExperiment(t, num_epochs, save_period, dump_skew_debug_plots)
    return vae_exp, train_data, test_data
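# A hedged usage sketch for the parameterized train_vae above: it builds the
# model and trainer eagerly but defers the epoch loop to the returned
# VAEExperiment. All values below are illustrative assumptions, not settings
# from the original project.
def _example_vae_experiment():
    vae_exp, train_data, test_data = train_vae(
        beta=5,
        representation_size=16,
        imsize=48,  # selects conv_vae.imsize48_default_architecture above
        num_epochs=200,
        save_period=25,
    )
    return vae_exp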
def skewfit_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer \
        import OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    import rlkit.torch.vae.vae_schedules as vae_schedules

    #### Gather the parameters for training the VAE and RIG.
    env = get_envs(variant)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    replay_buffer_kwargs = variant.get(
        'replay_buffer_kwargs',
        dict(
            start_skew_epoch=10,
            max_size=int(100000),
            fraction_goals_rollout_goals=0.2,
            fraction_goals_env_goals=0.5,
            exploration_rewards_type='None',
            vae_priority_type='vae_prob',
            priority_function_kwargs=dict(
                sampling_method='importance_sampling',
                decoder_distribution='gaussian_identity_variance',
                num_latents_to_sample=10,
            ),
            power=0,
            relabeling_goal_sampling_mode='vae_prior',
        ))
    online_vae_trainer_kwargs = variant.get('online_vae_trainer_kwargs',
                                            dict(beta=20, lr=1e-3))
    max_path_length = variant.get('max_path_length', 50)
    algo_kwargs = variant.get(
        'algo_kwargs',
        dict(
            batch_size=1024,
            num_epochs=1000,
            num_eval_steps_per_epoch=500,
            num_expl_steps_per_train_loop=500,
            num_trains_per_train_loop=1000,
            min_num_steps_before_training=10000,
            vae_training_schedule=vae_schedules.custom_schedule_2,
            oracle_data=False,
            vae_save_period=50,
            parallel_vae_train=False,
        ))
    twin_sac_trainer_kwargs = variant.get(
        'twin_sac_trainer_kwargs',
        dict(
            discount=0.99,
            reward_scale=1,
            soft_target_tau=1e-3,
            target_update_period=1,
            use_automatic_entropy_tuning=True,
        ))
    ###########################################################################
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=hidden_sizes)
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=hidden_sizes)
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=hidden_sizes)
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=hidden_sizes)
    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                hidden_sizes=hidden_sizes)
    vae = variant['vae_model']
    # Create a replay buffer for training an online VAE.
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs)
    # Create an online vae_trainer to train the VAE on the fly.
    vae_trainer = ConvVAETrainer(variant['vae_train_data'],
                                 variant['vae_test_data'], vae,
                                 **online_vae_trainer_kwargs)
    # Create a SACTrainer to learn a soft Q-function and a matching policy.
    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **twin_sac_trainer_kwargs)
    trainer = HERTrainer(trainer)
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant.get('evaluation_goal_sampling_mode', 'reset_of_env'),
        env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant.get('exploration_goal_sampling_mode', 'vae_prior'),
        env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        max_path_length=max_path_length,
        **algo_kwargs)
    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals
    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
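# A hedged sketch of the variant the skewfit_experiment above consumes. The
# keys shown are the ones read directly (vae_model, custom_goal_sampler) plus
# the dataset keys passed to ConvVAETrainer; everything else falls back to the
# defaults baked into the function. Pretraining with one of the train_vae
# variants in this file is an assumption about the intended wiring, and
# `vae_variant` (plus any env keys expected by get_envs) is hypothetical.
def _example_skewfit_run(vae_variant):
    vae, vae_train, vae_test = train_vae(vae_variant, return_data=True)
    skewfit_variant = dict(
        vae_model=vae,
        vae_train_data=vae_train,
        vae_test_data=vae_test,
        custom_goal_sampler='replay_buffer',  # sample goals from the buffer
    )
    skewfit_experiment(skewfit_variant)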
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    # Note: despite the name, vae_test_p is the fraction used for *training*;
    # the first 90% of examples go to the train set by default.
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])
    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    vae_class = vae_class.lower()
    if vae_class == 'vae':
        vae_class = ConvVAE
    elif vae_class == 'spatialvae':
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))
    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    # Log VAE pretraining to its own CSV, then restore the main progress file.
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
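# A hedged usage sketch for get_n_train_vae. `my_env` is a hypothetical
# goal-conditioned image env exposing goal_sampling_mode / sample_goals as
# assumed above; the architecture is the stock rlkit 48x48 conv architecture,
# and the kwargs values are illustrative.
def _example_get_n_train_vae(my_env):
    from rlkit.torch.vae.conv_vae import imsize48_default_architecture
    return get_n_train_vae(
        latent_dim=16,
        env=my_env,
        vae_train_epochs=100,
        num_image_examples=10000,
        vae_kwargs=dict(input_channels=3, imsize=48),
        vae_trainer_kwargs=dict(beta=20, lr=1e-3),
        vae_architecture=imsize48_default_architecture,
    )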
def train_vae(cfgs, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE,
    )
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    train_data, test_data, info = generate_vae_dataset(cfgs)
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    # FIXME: defaults to a Gaussian decoder when no activation is configured.
    if cfgs.VAE.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = cfgs.VAE.get('architecture', None)
    if not architecture and cfgs.ENV.get('img_size') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and cfgs.ENV.get('img_size') == 48:
        architecture = conv_vae.imsize48_default_architecture
    vae_model = ConvVAE(
        representation_size=cfgs.VAE.representation_size,
        architecture=architecture,
        decoder_output_activation=decoder_activation,
        input_channels=cfgs.VAE.input_channels,
        decoder_distribution=cfgs.VAE.decoder_distribution,
        imsize=cfgs.VAE.img_size,
    )
    vae_model.to(ptu.device)
    # FIXME: what is the role of beta_schedule?
    if 'beta_schedule_kwargs' in cfgs.VAE_TRAINER:
        beta_schedule = PiecewiseLinearSchedule(
            **cfgs.VAE_TRAINER.beta_schedule_kwargs)
    else:
        beta_schedule = None
    t = ConvVAETrainer(train_data,
                       test_data,
                       vae_model,
                       lr=cfgs.VAE_TRAINER.lr,
                       beta=cfgs.VAE_TRAINER.beta,
                       beta_schedule=beta_schedule)
    save_period = cfgs.VAE_TRAINER.save_period
    for epoch in range(cfgs.VAE_TRAINER.num_epochs):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
        t.update_train_weights()
    logger.save_extra_data(vae_model, 'vae.pkl', mode='pickle')
    if return_data:
        return vae_model, train_data, test_data
    return vae_model
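# A hedged sketch of the nested config this cfgs-based train_vae reads.
# Attribute access plus .get() suggests a yacs/EasyDict-style object; the tiny
# _AttrDict below is a self-contained stand-in, the node and key names are
# exactly those accessed above, and the values are illustrative assumptions.
class _AttrDict(dict):
    """Minimal config stand-in supporting both cfg.key and cfg.get(...)."""
    __getattr__ = dict.__getitem__


def _example_cfgs():
    return _AttrDict(
        ENV=_AttrDict(img_size=48),
        VAE=_AttrDict(
            representation_size=16,
            input_channels=3,
            decoder_distribution='gaussian_identity_variance',
            img_size=48,
        ),
        VAE_TRAINER=_AttrDict(lr=1e-3, beta=20, save_period=25,
                              num_epochs=200),
    )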
def skewfit_experiment(cfgs):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(cfgs)
    env = get_envs(cfgs)
    # TODO
    uniform_dataset_fn = cfgs.GENERATE_VAE_DATASET.get(
        'uniform_dataset_generator', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **cfgs.GENERATE_VAE_DATASET.generate_uniform_dataset_kwargs)
    else:
        uniform_dataset = None
    observation_key = cfgs.SKEW_FIT.get('observation_key',
                                        'latent_observation')
    desired_goal_key = cfgs.SKEW_FIT.get('desired_goal_key',
                                         'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = cfgs.Q_FUNCTION.get('hidden_sizes', [400, 300])
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=cfgs.POLICY.get('hidden_sizes', [400, 300]),
    )
    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        priority_function_kwargs=cfgs.PRIORITY_FUNCTION,
        **cfgs.REPLAY_BUFFER)
    vae_trainer = ConvVAETrainer(
        cfgs.VAE_TRAINER.train_data,
        cfgs.VAE_TRAINER.test_data,
        env.vae,
        beta=cfgs.VAE_TRAINER.beta,
        lr=cfgs.VAE_TRAINER.lr,
    )
    # assert 'vae_training_schedule' not in cfgs, "Just put it in algo_kwargs"
    max_path_length = cfgs.SKEW_FIT.max_path_length
    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **cfgs.TWIN_SAC_TRAINER)
    trainer = HERTrainer(trainer)
    eval_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.evaluation_goal_sampling_mode,
        env,
        MakeDeterministic(policy),
        decode_goals=True,  # TODO check this
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.exploration_goal_sampling_mode,
        env,
        policy,
        decode_goals=True,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,  # TODO used in test vae
        max_path_length=max_path_length,
        parallel_vae_train=cfgs.VAE_TRAINER.parallel_train,
        **cfgs.ALGORITHM)
    if cfgs.SKEW_FIT.custom_goal_sampler == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals
    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
def train_vae(variant, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get("generate_vae_data_fctn",
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant["generate_vae_dataset_kwargs"])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if "beta_schedule_kwargs" in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant["beta_schedule_kwargs"])
    else:
        beta_schedule = None
    if variant.get("decoder_activation", None) == "sigmoid":
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant["vae_kwargs"].get("architecture", None)
    if not architecture and variant.get("imsize") == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get("imsize") == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant["vae_kwargs"]["architecture"] = architecture
    variant["vae_kwargs"]["imsize"] = variant.get("imsize")
    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant["vae_kwargs"])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant["algo_kwargs"])
    save_period = variant["save_period"]
    dump_skew_debug_plots = variant.get("dump_skew_debug_plots", False)
    for epoch in range(variant["num_epochs"]):
        should_save_imgs = epoch % save_period == 0
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
        t.update_train_weights()
    logger.save_extra_data(m, "vae.pkl", mode="pickle")
    if return_data:
        return m, train_data, test_data
    return m
def train_vae(variant, return_data=False, skewfit_variant=None):
    # Actually trains both the VAE and the LSTM.
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE,
    )
    import rlkit.torch.vae.conv_vae as conv_vae
    import ROLL.LSTM_model as LSTM_model
    from ROLL.LSTM_model import ConvLSTM2
    from ROLL.LSTM_trainer import ConvLSTMTrainer
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    seg_pretrain = variant['seg_pretrain']
    ori_pretrain = variant['ori_pretrain']
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    generate_lstm_dataset_fctn = variant.get('generate_lstm_data_fctn')
    assert generate_lstm_dataset_fctn is not None, \
        "Must provide a custom generate lstm pretraining dataset function!"
    train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
        variant['generate_lstm_dataset_kwargs'],
        segmented=True,
        segmentation_method=skewfit_variant['segmentation_method'])
    train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    architecture = variant['lstm_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = None  # TODO LSTM: wrap an 84 LSTM architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = LSTM_model.imsize48_default_architecture
    variant['lstm_kwargs']['architecture'] = architecture
    variant['lstm_kwargs']['imsize'] = variant.get('imsize')

    train_datas = [train_data_lstm, train_data_ori]
    test_datas = [test_data_lstm, test_data_ori]
    names = ['lstm_seg_pretrain', 'vae_ori_pretrain']
    vaes = []
    env_id = variant['generate_lstm_dataset_kwargs'].get('env_id')
    assert env_id is not None
    lstm_pretrain_vae_only = variant.get('lstm_pretrain_vae_only', False)

    for idx in range(2):
        train_data, test_data, name = (train_datas[idx], test_datas[idx],
                                       names[idx])
        logger.add_tabular_output('{}_progress.csv'.format(name),
                                  relative_to_snapshot_dir=True)
        if idx == 1:
            # Train the original (unsegmented) VAE.
            representation_size = variant.get(
                "vae_representation_size", variant.get('representation_size'))
            beta = variant.get('vae_beta', variant.get('beta'))
            m = ConvVAE(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
            t = ConvVAETrainer(train_data,
                               test_data,
                               m,
                               beta=beta,
                               beta_schedule=beta_schedule,
                               **variant['algo_kwargs'])
        else:
            # Train the segmentation LSTM.
            lstm_version = variant.get('lstm_version', 2)
            if lstm_version == 2:
                lstm_class = ConvLSTM2
            representation_size = variant.get(
                "lstm_representation_size",
                variant.get('representation_size'))
            beta = variant.get('lstm_beta', variant.get('beta'))
            m = lstm_class(representation_size,
                           decoder_output_activation=decoder_activation,
                           **variant['lstm_kwargs'])
            t = ConvLSTMTrainer(train_data,
                                test_data,
                                m,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
        m.to(ptu.device)
        vaes.append(m)
        print("test data len: ", len(test_data))
        print("train data len: ", len(train_data))
        save_period = variant['save_period']
        pjhome = os.environ['PJHOME']
        if env_id == 'SawyerPushHurdle-v0' and osp.exists(
                osp.join(
                    pjhome, 'data/local/pre-train-lstm',
                    '{}-{}-{}-0.3-0.5.npy'.format('SawyerPushHurdle-v0',
                                                  'seg-color', '500'))):
            data_file_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                '{}-{}-{}-0.3-0.5.npy'.format(env_id, 'seg-color', 500))
            puck_pos_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                '{}-{}-{}-0.3-0.5-puck-pos.npy'.format(env_id, 'seg-color',
                                                       500))
            all_data = np.load(data_file_path)
            puck_pos = np.load(puck_pos_path)
            all_data = normalize_image(all_data.copy())
            obj_states = puck_pos
        else:
            all_data = np.concatenate([train_data_lstm, test_data_lstm],
                                      axis=0)
            all_data = normalize_image(all_data.copy())
            obj_states = info_lstm['obj_state']
        obj = 'door' if 'Door' in env_id else 'puck'

        num_epochs = (variant['num_lstm_epochs']
                      if idx == 0 else variant['num_vae_epochs'])
        if (idx == 0 and seg_pretrain) or (idx == 1 and ori_pretrain):
            for epoch in range(num_epochs):
                should_save_imgs = (epoch % save_period == 0)
                if idx == 0:
                    # Only the LSTM trainer has the 'only_train_vae' argument.
                    t.train_epoch(epoch,
                                  only_train_vae=lstm_pretrain_vae_only)
                    t.test_epoch(epoch,
                                 save_reconstruction=should_save_imgs,
                                 save_prefix='r_' + name,
                                 only_train_vae=lstm_pretrain_vae_only)
                else:
                    t.train_epoch(epoch)
                    t.test_epoch(
                        epoch,
                        save_reconstruction=should_save_imgs,
                        save_prefix='r_' + name,
                    )
                if should_save_imgs:
                    t.dump_samples(epoch, save_prefix='s_' + name)
                    if idx == 0:
                        compare_latent_distance(
                            m,
                            all_data,
                            obj_states,
                            obj_name=obj,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='lstm_latent_distance_{}.png'.format(
                                epoch))
                        test_lstm_traj(
                            env_id,
                            m,
                            save_path=logger.get_snapshot_dir(),
                            save_name='lstm_test_traj_{}.png'.format(epoch))
                        test_masked_traj_lstm(
                            env_id,
                            m,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='masked_test_{}.png'.format(epoch))
                t.update_train_weights()
            logger.save_extra_data(m, '{}.pkl'.format(name), mode='pickle')
        logger.remove_tabular_output('{}_progress.csv'.format(name),
                                     relative_to_snapshot_dir=True)
        if idx == 0 and variant.get("only_train_lstm", False):
            exit()
    if return_data:
        return vaes, train_datas, test_datas
    return m