def add_images(env, base_distribution):
    if use_image_observations:
        video_env = InsertImageEnv(
            env,
            renderer=video_renderer,
            image_key='video_observation',
        )
        image_goal_distribution = base_distribution
    else:
        video_env = InsertImageEnv(
            env,
            renderer=video_renderer,
            image_key='image_observation',
        )
        state_env = env.env
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=base_distribution,
            image_goal_key='image_desired_goal',
            renderer=video_renderer,
        )
    return ContextualEnv(
        video_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key_for_rl,
        update_env_info_fn=delete_info,
    )
def analyze_from_vae(
        snapshot_path,
        latent_observation_key='latent_observation',
        mean_key='latent_mean',
        covariance_key='latent_covariance',
        image_observation_key='image_observation',
        **kwargs
):
    data = torch.load(open(snapshot_path, "rb"))
    variant_path = snapshot_path.replace('params.pt', 'variant.json')
    print_settings(variant_path)
    vae = data['trainer/vae']
    state_env = gym.make('OneObject-PickAndPlace-BigBall-RandomInit-2D-v1')
    renderer = EnvRenderer()
    sets = make_custom_sets(state_env, renderer)
    reward_fn, _ = rewards.create_normal_likelihood_reward_fns(
        latent_observation_key=latent_observation_key,
        mean_key=mean_key,
        covariance_key=covariance_key,
        reward_fn_kwargs=dict(
            drop_log_det_term=True,
            sqrt_reward=True,
        ),
    )
    img_env = InsertImageEnv(state_env, renderer=renderer)
    env = DictEncoderWrappedEnv(
        img_env,
        vae,
        encoder_input_key='image_observation',
        encoder_output_remapping={'posterior_mean': 'latent_observation'},
    )
    analyze(sets, vae, env, **kwargs)
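# Usage sketch (assumption): `analyze_from_vae` expects an rlkit-style
# snapshot ('params.pt' with a 'variant.json' beside it); the path below is a
# hypothetical placeholder, and any extra keyword arguments pass through to
# `analyze`.
def analyze_example():
    analyze_from_vae('/path/to/exp/params.pt')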
def add_images(env, base_distribution):
    if use_image_observations:
        img_env = env
        image_goal_distribution = base_distribution
    else:
        state_env = env.env
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=base_distribution,
            image_goal_key='image_desired_goal',
            renderer=video_renderer,
        )
        img_env = InsertImageEnv(state_env, renderer=video_renderer)
    img_env = InsertDebugImagesEnv(
        img_env,
        obj1_sweep_renderers,
        compute_shared_data=obj1_sweeper,
    )
    img_env = InsertDebugImagesEnv(
        img_env,
        obj0_sweep_renderers,
        compute_shared_data=obj0_sweeper,
    )
    return ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key_for_rl,
        update_env_info_fn=delete_info,
    )
def setup_env(state_env, encoder, reward_fn):
    goal_distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )
    if use_image_observations:
        goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=env_renderer,
        )
        base_env = InsertImageEnv(state_env, renderer=env_renderer)
        goal_distribution = PresampledDistribution(
            goal_distribution, num_presampled_goals)
        goal_distribution = EncodedGoalDictDistribution(
            goal_distribution,
            encoder=encoder,
            keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
            encoder_input_key=img_desired_goal_key,
            encoder_output_key=latent_desired_goal_key,
        )
    else:
        base_env = state_env
        goal_distribution = EncodedGoalDictDistribution(
            goal_distribution,
            encoder=encoder,
            keys_to_keep=[state_desired_goal_key],
            encoder_input_key=state_desired_goal_key,
            encoder_output_key=latent_desired_goal_key,
        )
    goal_distribution = MaskedGoalDictDistribution(
        goal_distribution,
        mask_key=mask_key,
        mask_dim=latent_dim,
        distribution_type='one_hot_masks',
    )
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    env = ContextualEnv(
        base_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=delete_info,
        **contextual_env_kwargs,
    )
    return env, goal_distribution
def add_images(env, context_distribution):
    state_env = env.env
    img_env = InsertImageEnv(
        state_env,
        renderer=renderer,
        image_key='image_observation',
    )
    return ContextualEnv(
        img_env,
        context_distribution=context_distribution,
        reward_fn=eval_reward,
        observation_key=observation_key,
        update_env_info_fn=delete_info,
    )
def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                         renderer):
    state_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    state_env.goal_sampling_mode = goal_sampling_mode
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_goal_distribution,
        image_goal_key=img_desired_goal_key,
        renderer=renderer,
    )
    goal_distribution = PresampledDistribution(image_goal_distribution, 5000)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    if reward_type == 'state_distance':
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=state_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                'state_observation'),
            desired_goal_key=state_desired_goal_key,
            achieved_goal_key=state_achieved_goal_key,
        )
    elif reward_type == 'pixel_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                img_observation_key),
            desired_goal_key=img_desired_goal_key,
        )
    else:
        raise ValueError(reward_type)
    env = ContextualEnv(
        img_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=img_observation_key,
        contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
def add_images(env, state_distribution):
    state_env = env.env
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_distribution,
        image_goal_key='image_desired_goal',
        renderer=renderer,
    )
    img_env = InsertImageEnv(state_env, renderer=renderer)
    return ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=eval_reward,
        observation_key=observation_key,
        update_env_info_fn=delete_info,
    )
def get_video_func(
        env,
        policy,
        tag,
):
    renderer = EnvRenderer(**renderer_kwargs)
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    image_goal_distribution = AddImageDistribution(
        env=env,
        base_distribution=state_goal_distribution,
        image_goal_key="image_desired_goal",
        renderer=renderer,
    )
    img_env = InsertImageEnv(env, renderer=renderer)
    rollout_function = partial(
        rf.multitask_rollout,
        max_path_length=variant["max_path_length"],
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        return_dict_obs=True,
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            observation_key),
        desired_goal_key=desired_goal_key,
        achieved_goal_key="state_achieved_goal",
    )
    contextual_env = ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
    )
    video_func = get_save_video_function(
        rollout_function,
        contextual_env,
        policy,
        tag=tag,
        imsize=renderer.width,
        image_format="CWH",
        **save_video_kwargs,
    )
    return video_func
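# Usage sketch (assumption): one video function per environment, registered as
# a post-train hook; this mirrors the call pattern used by the experiments
# below, with `algorithm` and the envs/policies supplied by the caller.
def attach_video_funcs(algorithm, expl_env, eval_env, expl_policy, policy):
    expl_video_func = get_video_func(expl_env, expl_policy, "expl")
    eval_video_func = get_video_func(
        eval_env, MakeDeterministic(policy), "eval")
    algorithm.post_train_funcs.append(eval_video_func)
    algorithm.post_train_funcs.append(expl_video_func)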
def image_based_goal_conditioned_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        cnn_kwargs,
        policy_type='tanh-normal',
        env_id=None,
        env_class=None,
        env_kwargs=None,
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        reward_type='state_distance',
        env_renderer_kwargs=None,
        # Data augmentations
        apply_random_crops=False,
        random_crop_pixel_shift=4,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        video_renderer_kwargs=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not env_renderer_kwargs:
        env_renderer_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    img_observation_key = 'image_observation'
    img_desired_goal_key = 'image_desired_goal'
    state_observation_key = 'state_observation'
    state_desired_goal_key = 'state_desired_goal'
    state_achieved_goal_key = 'state_achieved_goal'
    sample_context_from_obs_dict_fn = RemapKeyFn({
        'image_desired_goal': 'image_observation',
        'state_desired_goal': 'state_observation',
    })

    def setup_contextual_env(env_id, env_class, env_kwargs,
                             goal_sampling_mode, renderer):
        state_env = get_gym_env(env_id, env_class=env_class,
                                env_kwargs=env_kwargs)
        state_env.goal_sampling_mode = goal_sampling_mode
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=renderer,
        )
        goal_distribution = PresampledDistribution(
            image_goal_distribution, 5000)
        img_env = InsertImageEnv(state_env, renderer=renderer)
        if reward_type == 'state_distance':
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=state_env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    'state_observation'),
                desired_goal_key=state_desired_goal_key,
                achieved_goal_key=state_achieved_goal_key,
            )
        elif reward_type == 'pixel_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    img_observation_key),
                desired_goal_key=img_desired_goal_key,
            )
        else:
            raise ValueError(reward_type)
        env = ContextualEnv(
            img_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=img_observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

    env_renderer = EnvRenderer(**env_renderer_kwargs)
    expl_env, expl_context_distrib, expl_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        env_renderer)
    eval_env, eval_context_distrib, eval_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        env_renderer)
    action_dim = expl_env.action_space.low.size
    if env_renderer.output_image_format == 'WHC':
        img_width, img_height, img_num_channels = (
            expl_env.observation_space[img_observation_key].shape)
    elif env_renderer.output_image_format == 'CHW':
        img_num_channels, img_height, img_width = (
            expl_env.observation_space[img_observation_key].shape)
    else:
        raise ValueError(env_renderer.output_image_format)

    def create_qf():
        cnn = BasicCNN(
            input_width=img_width,
            input_height=img_height,
            input_channels=img_num_channels,
            **cnn_kwargs)
        joint_cnn = ApplyConvToStateAndGoalImage(cnn)
        return basic.MultiInputSequential(
            ApplyToObs(joint_cnn),
            basic.FlattenEachParallel(),
            ConcatMlp(
                input_size=joint_cnn.output_size + action_dim,
                output_size=1,
                **qf_kwargs))

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    cnn = BasicCNN(
        input_width=img_width,
        input_height=img_height,
        input_channels=img_num_channels,
        **cnn_kwargs)
    joint_cnn = ApplyConvToStateAndGoalImage(cnn)
    policy_obs_dim = joint_cnn.output_size
    if policy_type == 'normal':
        obs_processor = nn.Sequential(
            joint_cnn,
            basic.Flatten(),
            MultiHeadedMlp(
                input_size=policy_obs_dim,
                output_sizes=[action_dim, action_dim],
                **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    elif policy_type == 'tanh-normal':
        obs_processor = nn.Sequential(
            joint_cnn,
            basic.Flatten(),
            MultiHeadedMlp(
                input_size=policy_obs_dim,
                output_sizes=[action_dim, action_dim],
                **policy_kwargs))
        policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))
    elif policy_type == 'normal-tanh-mean':
        obs_processor = nn.Sequential(
            joint_cnn,
            basic.Flatten(),
            MultiHeadedMlp(
                input_size=policy_obs_dim,
                output_sizes=[action_dim, action_dim],
                output_activations=['tanh', 'identity'],
                **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    else:
        raise ValueError("Unknown policy type: {}".format(policy_type))
    if apply_random_crops:
        pad = BatchPad(
            env_renderer.output_image_format,
            random_crop_pixel_shift,
            random_crop_pixel_shift,
        )
        crop = JointRandomCrop(
            env_renderer.output_image_format,
            env_renderer.image_shape,
        )

        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            obs_padded = pad(obs)
            next_obs_padded = pad(next_obs)
            context_padded = pad(context)
            obs_aug, context_aug = crop(obs_padded, context_padded)
            next_obs_aug, next_context_aug = crop(
                next_obs_padded, context_padded)
            batch['observations'] = np.concatenate(
                [obs_aug, context_aug], axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs_aug, next_context_aug], axis=1)
            return batch
    else:
        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            batch['observations'] = np.concatenate([obs, context], axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs, context], axis=1)
            return batch
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[img_desired_goal_key, state_desired_goal_key],
        observation_key=img_observation_key,
        observation_keys=[img_observation_key, state_observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs)
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    exploration_policy = create_exploration_policy(
        expl_env, policy, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)
    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=img_observation_key,
            context_keys_for_policy=[img_desired_goal_key],
        )
        video_renderer = EnvRenderer(**video_renderer_kwargs)
        video_eval_env = InsertImageEnv(
            eval_env, renderer=video_renderer, image_key='video_observation')
        video_expl_env = InsertImageEnv(
            expl_env, renderer=video_renderer, image_key='video_observation')
        video_eval_env = ContextualEnv(
            video_eval_env,
            context_distribution=eval_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        video_expl_env = ContextualEnv(
            video_expl_env,
            context_distribution=expl_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        eval_video_func = get_save_video_function(
            rollout_function,
            video_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal',
                'image_observation',
                'video_observation',
            ],
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            video_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal',
                'image_observation',
                'video_observation',
            ],
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)
    algorithm.train()
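# Usage sketch: a minimal, illustrative invocation of the experiment above.
# All hyperparameter values are placeholders, the env id is hypothetical, and
# the valid goal-sampling-mode strings depend on the env; the env must expose
# `goal_sampling_mode` and `goal_conditioned_diagnostics`.
def run_image_sac_example():
    image_based_goal_conditioned_sac_experiment(
        max_path_length=100,
        qf_kwargs=dict(hidden_sizes=[256, 256]),
        sac_trainer_kwargs=dict(discount=0.99, soft_target_tau=1e-3),
        replay_buffer_kwargs=dict(
            max_size=int(1e5),
            fraction_future_context=0.4,
            fraction_distribution_context=0.4,
        ),
        policy_kwargs=dict(hidden_sizes=[256, 256]),
        algo_kwargs=dict(
            batch_size=256,
            num_epochs=100,
            num_eval_steps_per_epoch=1000,
            num_expl_steps_per_train_loop=1000,
            num_trains_per_train_loop=1000,
            min_num_steps_before_training=1000,
        ),
        cnn_kwargs=dict(
            kernel_sizes=[3, 3],
            n_channels=[16, 32],
            strides=[1, 1],
            paddings=[0, 0],
        ),
        env_id='ExampleGoalEnv-v0',  # hypothetical env id
        evaluation_goal_sampling_mode='reset_of_env',
        exploration_goal_sampling_mode='reset_of_env',
    )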
def get_img_env(env):
    renderer = EnvRenderer(**variant["renderer_kwargs"])
    # NOTE: the original snippet assigned the wrapped env without returning
    # it; returning it is the presumed intent.
    return InsertImageEnv(GymToMultiEnv(env), renderer=renderer)
def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                      goal_sampling_mode):
    state_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    encoded_env = EncoderWrappedEnv(
        img_env,
        model,
        dict(image_observation="latent_observation"),
    )
    if goal_sampling_mode == "vae_prior":
        latent_goal_distribution = PriorDistribution(
            model.representation_size,
            desired_goal_key,
        )
        diagnostics = StateImageGoalDiagnosticsFn({})
    elif goal_sampling_mode == "reset_of_env":
        state_goal_env = get_gym_env(
            env_id, env_class=env_class, env_kwargs=env_kwargs)
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_goal_env,
            desired_goal_keys=[state_goal_key],
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=image_goal_key,
            renderer=renderer,
        )
        latent_goal_distribution = AddLatentDistribution(
            image_goal_distribution,
            image_goal_key,
            desired_goal_key,
            model,
        )
        if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
            diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                state_goal_env.goal_conditioned_diagnostics,
                desired_goal_key=state_goal_key,
                observation_key=state_observation_key,
            )
        else:
            diagnostics = state_goal_env.get_contextual_diagnostics
    else:
        raise NotImplementedError(
            'unknown goal sampling method: %s' % goal_sampling_mode)
    reward_fn = DistanceRewardFn(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    env = ContextualEnv(
        encoded_env,
        context_distribution=latent_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics],
    )
    return env, latent_goal_distribution, reward_fn
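# Usage sketch (assumption): building the evaluation and exploration envs from
# the factory above; "reset_of_env" and "vae_prior" are the two modes it
# handles, and env_id/env_class/env_kwargs come from the surrounding config.
def make_eval_and_expl_envs(env_id, env_class, env_kwargs):
    eval_env, eval_goal_dist, eval_reward_fn = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, "reset_of_env"))
    expl_env, expl_goal_dist, expl_reward_fn = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, "vae_prior"))
    return eval_env, expl_env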
def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                      goal_sampling_mode,
                                      presampled_goals_path):
    state_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    # The encoder wrapper and the "vae_prior" / "presampled" / "reset_of_env"
    # goal-sampling branches (which would use goal_sampling_mode and
    # presampled_goals_path, as in the function above) are disabled in this
    # variant: the env runs directly on images with a trivial "no goal"
    # distribution.
    diagnostics = StateImageGoalDiagnosticsFn({})
    no_goal_distribution = PriorDistribution(
        representation_size=0,
        key="no_goal",
    )
    # GraspingRewardFn's constructor arguments are likewise disabled here.
    reward_fn = GraspingRewardFn()
    env = ContextualEnv(
        img_env,
        context_distribution=no_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[diagnostics],
    )
    return env, no_goal_distribution, reward_fn
def experiment(variant):
    render = variant.get("render", False)
    debug = variant.get("debug", False)
    vae_path = variant.get("vae_path", False)
    process_args(variant)
    env_class = variant.get("env_class")
    env_kwargs = variant.get("env_kwargs")
    env_id = variant.get("env_id")
    expl_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    eval_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    env = eval_env
    if variant.get('sparse_reward', False):
        expl_env = RewardWrapperEnv(expl_env, compute_hand_sparse_reward)
        eval_env = RewardWrapperEnv(eval_env, compute_hand_sparse_reward)
    if variant.get("vae_path", False):
        vae = load_local_or_remote_file(vae_path)
        variant['path_loader_kwargs']['model_path'] = vae_path
        renderer = EnvRenderer(**variant.get("renderer_kwargs", {}))
        expl_env = VQVAEWrappedEnv(
            InsertImageEnv(expl_env, renderer=renderer),
            vae,
            reward_params=variant.get("reward_params", {}),
            **variant.get('vae_wrapped_env_kwargs', {}))
        eval_env = VQVAEWrappedEnv(
            InsertImageEnv(eval_env, renderer=renderer),
            vae,
            reward_params=variant.get("reward_params", {}),
            **variant.get('vae_wrapped_env_kwargs', {}))
        env = eval_env
    variant['path_loader_kwargs']['env'] = env
    if variant.get('add_env_demos', False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_demo_path"])
    if variant.get('add_env_offpolicy_data', False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_offpolicy_data_path"])
    if variant.get("use_masks", False):
        mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict())
        expl_mask_distribution_kwargs = variant[
            "expl_mask_distribution_kwargs"]
        expl_mask_distribution = DiscreteDistribution(
            **expl_mask_distribution_kwargs)
        expl_env = RewardMaskWrapper(env, expl_mask_distribution,
                                     **mask_wrapper_kwargs)
        eval_mask_distribution_kwargs = variant[
            "eval_mask_distribution_kwargs"]
        eval_mask_distribution = DiscreteDistribution(
            **eval_mask_distribution_kwargs)
        eval_env = RewardMaskWrapper(env, eval_mask_distribution,
                                     **mask_wrapper_kwargs)
        env = eval_env
    if variant.get("pretrained_algorithm_path", False):
        resume(variant)
        return
    path_loader_kwargs = variant.get("path_loader_kwargs", {})
    stack_obs = path_loader_kwargs.get("stack_obs", 1)
    if stack_obs > 1:
        expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs)
        eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key',
                                    'latent_achieved_goal')
    obs_dim = (env.observation_space.spaces[observation_key].low.size
               + env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = eval_env.action_space.low.size
    if hasattr(expl_env, 'info_sizes'):
        env_info_sizes = expl_env.info_sizes
    else:
        env_info_sizes = dict()
    replay_buffer_kwargs = dict(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )
    replay_buffer_kwargs.update(variant.get('replay_buffer_kwargs', dict()))
    replay_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        ["resampled_goals"],
    )
    replay_buffer_kwargs.update(
        variant.get('demo_replay_buffer_kwargs', dict()))
    demo_train_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        ["resampled_goals"],
    )
    demo_test_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
"resampled_goals", ], ) M = variant['layer_size'] qf1 = ConcatMlp( input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M], ) qf2 = ConcatMlp( input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M], ) target_qf1 = ConcatMlp( input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M], ) target_qf2 = ConcatMlp( input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M], ) policy_class = variant.get("policy_class", TanhGaussianPolicy) policy = policy_class( obs_dim=obs_dim, action_dim=action_dim, **variant['policy_kwargs'], ) expl_policy = policy exploration_kwargs = variant.get('exploration_kwargs', {}) if exploration_kwargs: if exploration_kwargs.get("deterministic_exploration", False): expl_policy = MakeDeterministic(policy) exploration_strategy = exploration_kwargs.get("strategy", None) if exploration_strategy is None: pass elif exploration_strategy == 'ou': es = OUStrategy( action_space=expl_env.action_space, max_sigma=exploration_kwargs['noise'], min_sigma=exploration_kwargs['noise'], ) expl_policy = PolicyWrappedWithExplorationStrategy( exploration_strategy=es, policy=expl_policy, ) elif exploration_strategy == 'gauss_eps': es = GaussianAndEpislonStrategy( action_space=expl_env.action_space, max_sigma=exploration_kwargs['noise'], min_sigma=exploration_kwargs['noise'], # constant sigma epsilon=0, ) expl_policy = PolicyWrappedWithExplorationStrategy( exploration_strategy=es, policy=expl_policy, ) else: error trainer = AWACTrainer(env=eval_env, policy=policy, qf1=qf1, qf2=qf2, target_qf1=target_qf1, target_qf2=target_qf2, **variant['trainer_kwargs']) if variant['collection_mode'] == 'online': expl_path_collector = MdpStepCollector( expl_env, policy, ) algorithm = TorchOnlineRLAlgorithm( trainer=trainer, exploration_env=expl_env, evaluation_env=eval_env, exploration_data_collector=expl_path_collector, evaluation_data_collector=eval_path_collector, replay_buffer=replay_buffer, max_path_length=variant['max_path_length'], batch_size=variant['batch_size'], num_epochs=variant['num_epochs'], num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'], num_expl_steps_per_train_loop=variant[ 'num_expl_steps_per_train_loop'], num_trains_per_train_loop=variant['num_trains_per_train_loop'], min_num_steps_before_training=variant[ 'min_num_steps_before_training'], ) else: eval_path_collector = GoalConditionedPathCollector( eval_env, MakeDeterministic(policy), observation_key=observation_key, desired_goal_key=desired_goal_key, render=render, goal_sampling_mode=variant.get("goal_sampling_mode", None), ) expl_path_collector = GoalConditionedPathCollector( expl_env, policy, observation_key=observation_key, desired_goal_key=desired_goal_key, render=render, goal_sampling_mode=variant.get("goal_sampling_mode", None), ) algorithm = TorchBatchRLAlgorithm( trainer=trainer, exploration_env=expl_env, evaluation_env=eval_env, exploration_data_collector=expl_path_collector, evaluation_data_collector=eval_path_collector, replay_buffer=replay_buffer, max_path_length=variant['max_path_length'], batch_size=variant['batch_size'], num_epochs=variant['num_epochs'], num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'], num_expl_steps_per_train_loop=variant[ 'num_expl_steps_per_train_loop'], num_trains_per_train_loop=variant['num_trains_per_train_loop'], min_num_steps_before_training=variant[ 'min_num_steps_before_training'], ) algorithm.to(ptu.device) if variant.get("save_video", True): video_func = VideoSaveFunction( env, variant, ) 
        algorithm.post_train_funcs.append(video_func)
    if variant.get('save_paths', False):
        algorithm.post_train_funcs.append(save_paths)
    if variant.get('load_demos', False):
        path_loader_class = variant.get('path_loader_class', MDPPathLoader)
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            **path_loader_kwargs)
        path_loader.load_demos()
    if variant.get('pretrain_policy', False):
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if variant.get('pretrain_rl', False):
        trainer.pretrain_q_with_bc_data()
    if variant.get('save_pretrained_algorithm', False):
        p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p')
        pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        torch.save(data, open(pt_path, "wb"))
        torch.save(data, open(p_path, "wb"))
    algorithm.train()
def contextual_env_distrib_and_reward(
        vae,
        sets: typing.List[Set],
        state_env,
        renderer,
        reward_fn_kwargs,
        use_ground_truth_reward,
        state_observation_key,
        latent_observation_key,
        example_image_key,
        set_description_key,
        observation_key,
        image_observation_key,
        rig_goal_setter_kwargs,
        oracle_rig_goal=False,
):
    img_env = InsertImageEnv(state_env, renderer=renderer)
    encoded_env = EncoderWrappedEnv(
        img_env,
        vae,
        step_keys_map={image_observation_key: latent_observation_key},
    )
    if oracle_rig_goal:
        context_env_class = InitStateConditionedContextualEnv
        goal_distribution_params_distribution = OracleRIGMeanSetter(
            sets,
            vae,
            example_image_key,
            env=state_env,
            renderer=renderer,
            cycle_for_batch_size_1=True,
            **rig_goal_setter_kwargs)
    else:
        context_env_class = ContextualEnv
        goal_distribution_params_distribution = (
            LatentGoalDictDistributionFromSet(
                sets,
                vae,
                example_image_key,
                cycle_for_batch_size_1=True,
            ))
    if use_ground_truth_reward:
        reward_fn, unbatched_reward_fn = create_ground_truth_set_rewards_fns(
            sets,
            goal_distribution_params_distribution.set_index_key,
            state_observation_key,
        )
    else:
        reward_fn, unbatched_reward_fn = create_normal_likelihood_reward_fns(
            latent_observation_key,
            goal_distribution_params_distribution.mean_key,
            goal_distribution_params_distribution.covariance_key,
            reward_fn_kwargs,
        )
    set_diagnostics = SetDiagnostics(
        set_description_key=set_description_key,
        set_index_key=goal_distribution_params_distribution.set_index_key,
        observation_key=state_observation_key,
    )
    env = context_env_class(
        encoded_env,
        context_distribution=goal_distribution_params_distribution,
        reward_fn=reward_fn,
        unbatched_reward_fn=unbatched_reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[
            # goal_diagnostics,
            set_diagnostics,
        ],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution_params_distribution, reward_fn
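# Usage sketch (assumption): invoking the set-conditioned factory above with
# key names mirroring those used in analyze_from_vae earlier in this file;
# 'example_image' and 'set_description' are hypothetical key names, and the
# reward_fn_kwargs values echo the ones analyze_from_vae passes.
def make_set_conditioned_env(vae, sets, state_env, renderer):
    return contextual_env_distrib_and_reward(
        vae=vae,
        sets=sets,
        state_env=state_env,
        renderer=renderer,
        reward_fn_kwargs=dict(drop_log_det_term=True, sqrt_reward=True),
        use_ground_truth_reward=False,
        state_observation_key='state_observation',
        latent_observation_key='latent_observation',
        example_image_key='example_image',
        set_description_key='set_description',
        observation_key='latent_observation',
        image_observation_key='image_observation',
        rig_goal_setter_kwargs={},
    )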