def experiment(variant):
    # create multi-task environment and sample tasks
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    tasks = env.get_all_task_idx()
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1

    # instantiate networks
    latent_dim = variant['latent_size']
    context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim if variant['algo_params']['use_next_obs_in_context'] else obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2 if variant['algo_params']['use_information_bottleneck'] else latent_dim
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + latent_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy, **variant['algo_params'])
    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(tasks[:variant['n_train_tasks']]),
        eval_tasks=list(tasks[-variant['n_eval_tasks']:]),
        nets=[agent, qf1, qf2, vf],
        latent_dim=latent_dim,
        **variant['algo_params'],
    )

    # optionally load pre-trained weights
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        context_encoder.load_state_dict(torch.load(os.path.join(path, 'context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        vf.load_state_dict(torch.load(os.path.join(path, 'vf.pth')))
        # TODO hacky, revisit after model refactor
        algorithm.networks[-2].load_state_dict(torch.load(os.path.join(path, 'target_vf.pth')))
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))

    # create logging directory
    # TODO support Docker
    exp_id = 'debug' if DEBUG else None
    experiment_log_dir = setup_logger(
        variant['env_name'],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'],
    )

    # optionally save eval trajectories as pkl files
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()
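# --- Illustration (not part of the original launcher) ---
# A minimal sketch of the kind of `variant` dict that experiment() above reads.
# The keys mirror the lookups in the function; the concrete values are hypothetical
# placeholders, and `algo_params` additionally carries whatever keyword arguments
# PEARLAgent and PEARLSoftActorCritic expect, which are not listed here.
example_variant = dict(
    env_name='cheetah-dir',        # hypothetical key into ENVS
    env_params=dict(n_tasks=2),    # hypothetical env constructor kwargs
    n_train_tasks=2,
    n_eval_tasks=2,
    latent_size=5,
    net_size=300,
    path_to_weights=None,          # set to a directory of .pth files to warm-start
    algo_params=dict(
        use_next_obs_in_context=False,
        use_information_bottleneck=True,
        recurrent=False,
        dump_eval_paths=False,
        # ... remaining agent/algorithm kwargs ...
    ),
    util_params=dict(
        use_gpu=True,
        gpu_id=0,
        debug=False,
        base_log_dir='output',
    ),
)
# experiment(example_variant)  # would require the remaining algo_params to be filled in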
def experiment(variant):
    domain = variant['domain']
    seed = variant['seed']
    exp_mode = variant['exp_mode']
    max_path_length = variant['algo_params']['max_path_length']
    bcq_interactions = variant['bcq_interactions']
    num_tasks = variant['num_tasks']

    filename = f'./goals/{domain}-{exp_mode}-goals.pkl'
    idx_list, train_goals, wd_goals, ood_goals = pickle.load(open(filename, 'rb'))
    idx_list = idx_list[:num_tasks]

    sub_buffer_dir = f"buffers/{domain}/{exp_mode}/max_path_length_{max_path_length}/interactions_{bcq_interactions}k/seed_{seed}"
    buffer_dir = os.path.join(variant['data_models_root'], sub_buffer_dir)
    print("Buffer directory: " + buffer_dir)

    # Load buffers
    bcq_buffers = []
    buffer_loader_id_list = []
    for i, idx in enumerate(idx_list):
        bname = f'goal_0{idx}.zip_pkl' if idx < 10 else f'goal_{idx}.zip_pkl'
        filename = os.path.join(buffer_dir, bname)
        rp_buffer = ReplayBuffer.remote(
            index=i,
            seed=seed,
            num_trans_context=variant['num_trans_context'],
            in_mdp_batch_size=variant['in_mdp_batch_size'],
        )
        buffer_loader_id_list.append(rp_buffer.load_from_gzip.remote(filename))
        bcq_buffers.append(rp_buffer)
    ray.get(buffer_loader_id_list)

    assert len(bcq_buffers) == len(idx_list)
    train_buffer = MultiTaskReplayBuffer(bcq_buffers_list=bcq_buffers)

    set_seed(variant['seed'])

    # create multi-task environment and sample tasks
    env = env_producer(variant['domain'], seed=0)
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1

    # instantiate networks
    latent_dim = variant['latent_size']
    context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim if variant['algo_params']['use_next_obs_in_context'] else obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2 if variant['algo_params']['use_information_bottleneck'] else latent_dim
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + latent_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy, **variant['algo_params'])
    algorithm = PEARLSoftActorCritic(
        env=env,
        train_goals=train_goals,
        wd_goals=wd_goals,
        ood_goals=ood_goals,
        replay_buffers=train_buffer,
        nets=[agent, qf1, qf2, vf],
        latent_dim=latent_dim,
        **variant['algo_params'],
    )

    # optionally load pre-trained weights
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        context_encoder.load_state_dict(torch.load(os.path.join(path, 'context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        vf.load_state_dict(torch.load(os.path.join(path, 'vf.pth')))
        # TODO hacky, revisit after model refactor
        algorithm.networks[-2].load_state_dict(torch.load(os.path.join(path, 'target_vf.pth')))
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))

    # create logging directory
    # TODO support Docker
    exp_id = 'debug' if DEBUG else None
    experiment_log_dir = setup_logger(
        variant['domain'],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'],
    )

    # optionally save eval trajectories as pkl files
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()
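# --- Illustration (not part of the original launcher) ---
# The offline variant above unpickles a 4-tuple (idx_list, train_goals, wd_goals,
# ood_goals) from ./goals/{domain}-{exp_mode}-goals.pkl. A hypothetical sketch of
# how such a file could be produced; the domain name, goal ranges, and counts are
# made-up placeholders and the real goal-generation script may differ.
import os
import pickle
import numpy as np

idx_list = list(range(20))                          # task indices matching the buffer files
train_goals = np.random.uniform(0.0, 3.0, size=20)  # training-task goals
wd_goals = np.random.uniform(0.0, 3.0, size=8)      # within-distribution eval goals
ood_goals = np.random.uniform(3.0, 4.0, size=8)     # out-of-distribution eval goals

os.makedirs('./goals', exist_ok=True)
with open('./goals/cheetah-vel-expert-goals.pkl', 'wb') as f:
    pickle.dump((idx_list, train_goals, wd_goals, ood_goals), f)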
def experiment(variant):
    print(variant['env_name'])
    print(variant['env_params'])

    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    tasks = env.get_all_task_idx()
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))

    cont_latent_dim, num_cat, latent_dim, num_dir, dir_latent_dim = read_dim(variant['global_latent'])
    r_cont_dim, r_n_cat, r_cat_dim, r_n_dir, r_dir_dim = read_dim(variant['vrnn_latent'])
    reward_dim = 1
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    glob = variant['algo_params']['glob']
    rnn = variant['rnn']
    vrnn_latent = variant['vrnn_latent']
    encoder_model = MlpEncoder

    if recurrent:
        if variant['vrnn_constraint'] == 'logitnormal':
            output_size = r_cont_dim * 2 + r_n_cat * r_cat_dim + r_n_dir * r_dir_dim * 2
        else:
            output_size = r_cont_dim * 2 + r_n_cat * r_cat_dim + r_n_dir * r_dir_dim

        if variant['rnn_sample'] == 'batch_sampling':
            if variant['algo_params']['use_next_obs']:
                input_size = (2 * obs_dim + action_dim + reward_dim) * variant['temp_res']
            else:
                input_size = (obs_dim + action_dim + reward_dim) * variant['temp_res']
        else:
            if variant['algo_params']['use_next_obs']:
                input_size = 2 * obs_dim + action_dim + reward_dim
            else:
                input_size = obs_dim + action_dim + reward_dim

        if rnn == 'rnn':
            recurrent_model = RecurrentEncoder
            recurrent_context_encoder = recurrent_model(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=input_size,
                output_size=output_size,
            )
        elif rnn == 'vrnn':
            recurrent_model = VRNNEncoder
            recurrent_context_encoder = recurrent_model(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=input_size,
                output_size=output_size,
                temperature=variant['temperature'],
                vrnn_latent=variant['vrnn_latent'],
                vrnn_constraint=variant['vrnn_constraint'],
                r_alpha=variant['vrnn_alpha'],
                r_var=variant['vrnn_var'],
            )
    else:
        recurrent_context_encoder = None

    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])

    if glob:
        if dir_latent_dim > 0 and variant['constraint'] == 'logitnormal':
            output_size = cont_latent_dim * 2 + num_cat * latent_dim + num_dir * dir_latent_dim * 2
        else:
            output_size = cont_latent_dim * 2 + num_cat * latent_dim + num_dir * dir_latent_dim
        if variant['algo_params']['use_next_obs']:
            input_size = 2 * obs_dim + action_dim + reward_dim
        else:
            input_size = obs_dim + action_dim + reward_dim
        global_context_encoder = encoder_model(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=input_size,
            output_size=output_size,
        )
    else:
        global_context_encoder = None

    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim * num_cat + cont_latent_dim + dir_latent_dim * num_dir
                   + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim * num_cat + cont_latent_dim + dir_latent_dim * num_dir
                   + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim,
        output_size=1,
    )
    target_qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim * num_cat + cont_latent_dim + dir_latent_dim * num_dir
                   + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim,
        output_size=1,
    )
    target_qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim * num_cat + cont_latent_dim + dir_latent_dim * num_dir
                   + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim * num_cat + cont_latent_dim + dir_latent_dim * num_dir
                + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim,
        latent_dim=latent_dim * num_cat + cont_latent_dim + dir_latent_dim * num_dir
                   + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(
        global_context_encoder,
        recurrent_context_encoder,
        variant['global_latent'],
        variant['vrnn_latent'],
        policy,
        variant['temperature'],
        variant['unitkl'],
        variant['alpha'],
        variant['constraint'],
        variant['vrnn_constraint'],
        variant['var'],
        variant['vrnn_alpha'],
        variant['vrnn_var'],
        rnn,
        variant['temp_res'],
        variant['rnn_sample'],
        variant['weighted_sample'],
        **variant['algo_params'],
    )

    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        with open(os.path.join(path, 'extra_data.pkl'), 'rb') as f:
            extra_data = pickle.load(f)
        variant['algo_params']['start_epoch'] = extra_data['epoch'] + 1
        replay_buffer = extra_data['replay_buffer']
        enc_replay_buffer = extra_data['enc_replay_buffer']
        variant['algo_params']['_n_train_steps_total'] = extra_data['_n_train_steps_total']
        variant['algo_params']['_n_env_steps_total'] = extra_data['_n_env_steps_total']
        variant['algo_params']['_n_rollouts_total'] = extra_data['_n_rollouts_total']
    else:
        replay_buffer = None
        enc_replay_buffer = None

    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(tasks[:variant['n_train_tasks']]),
        eval_tasks=list(tasks[-variant['n_eval_tasks']:]),
        nets=[agent, qf1, qf2, target_qf1, target_qf2],
        latent_dim=latent_dim,
        replay_buffer=replay_buffer,
        enc_replay_buffer=enc_replay_buffer,
        temp_res=variant['temp_res'],
        rnn_sample=variant['rnn_sample'],
        **variant['algo_params'],
    )

    # optionally load pre-trained weights
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        if recurrent_context_encoder is not None:
            recurrent_context_encoder.load_state_dict(torch.load(os.path.join(path, 'recurrent_context_encoder.pth')))
        if global_context_encoder is not None:
            global_context_encoder.load_state_dict(torch.load(os.path.join(path, 'global_context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        target_qf1.load_state_dict(torch.load(os.path.join(path, 'target_qf1.pth')))
        target_qf2.load_state_dict(torch.load(os.path.join(path, 'target_qf2.pth')))
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    if ptu.gpu_enabled():
        algorithm.to()

    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))

    exp_id = 'debug' if DEBUG else None
    if variant.get('log_name', "") == "":
        log_name = variant['env_name']
    else:
        log_name = variant['log_name']
    experiment_log_dir = setup_logger(
        log_name,
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'],
        config_log_dir=variant['util_params']['config_log_dir'],
        log_dir=variant['util_params']['log_dir'],
    )

    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    env.save_all_tasks(experiment_log_dir)

    if variant['eval']:
        algorithm._try_to_eval(0, eval_all=True, eval_train_offline=False, animated=True)
    else:
        algorithm.train()
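# --- Illustration (not part of the original launcher) ---
# When resuming from variant['path_to_weights'], the launcher above expects
# extra_data.pkl to unpickle into a dict with at least the keys it reads.
# A sketch of that structure; the counters and None buffers are placeholders,
# the real file is written by the training loop.
example_extra_data = {
    'epoch': 120,                  # training resumes at epoch + 1
    'replay_buffer': None,         # in the real file: the saved RL replay buffer object
    'enc_replay_buffer': None,     # in the real file: the saved encoder context buffer
    '_n_train_steps_total': 0,
    '_n_env_steps_total': 0,
    '_n_rollouts_total': 0,
}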
def setup_and_run(variant):
    ptu.set_gpu_mode(variant['util_params']['use_gpu'],
                     variant['seed'] % variant['util_params']['num_gpus'])

    # setup env
    env_name = variant['env_name']
    env_params = variant['env_params']
    env_params['n_tasks'] = variant["n_train_tasks"] + variant["n_eval_tasks"]
    env = NormalizedBoxEnv(ENVS[env_name](**env_params))
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    latent_dim = variant['latent_size']
    reward_dim = 1

    # setup encoder
    context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim if variant['algo_params']['use_next_obs_in_context'] else obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2 if variant['algo_params']['use_information_bottleneck'] else latent_dim
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder
    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )

    # setup actor, critic
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    target_qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    target_qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy, **variant['algo_params'])
    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(np.arange(variant['n_train_tasks'])),
        eval_tasks=list(np.arange(variant['n_train_tasks'],
                                  variant['n_train_tasks'] + variant['n_eval_tasks'])),
        nets=[agent, qf1, qf2, target_qf1, target_qf2],
        latent_dim=latent_dim,
        **variant['algo_params'],
    )

    # optionally load pre-trained weights
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        context_encoder.load_state_dict(torch.load(os.path.join(path, 'context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        target_qf1.load_state_dict(torch.load(os.path.join(path, 'target_qf1.pth')))
        target_qf2.load_state_dict(torch.load(os.path.join(path, 'target_qf2.pth')))
        # TODO hacky, revisit after model refactor
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    if ptu.gpu_enabled():
        algorithm.to()

    os.environ['DEBUG'] = str(int(variant['util_params']['debug']))

    # setup logger
    run_mode = variant['run_mode']
    exp_log_name = os.path.join(
        variant['env_name'], run_mode,
        variant['log_annotation'] + variant['variant_name'],
        'seed-' + str(variant['seed']),
    )
    setup_logger(
        exp_log_name,
        variant=variant,
        exp_id=None,
        base_log_dir=os.environ.get('PEARL_DATA_PATH'),
        snapshot_mode='gap',
        snapshot_gap=10,
    )

    # run the algorithm
    if run_mode == 'TRAIN':
        algorithm.train()
    elif run_mode == 'EVAL':
        assert variant['algo_params']['dump_eval_paths'] == True
        algorithm._try_to_eval()
    else:
        algorithm.eval_with_loaded_latent()
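# --- Illustration (not part of the original launchers) ---
# The same context-encoder dimension arithmetic appears inline in several of the
# launchers above. A small helper capturing it, written as a sketch (the function
# name is hypothetical and not part of the original code):
def context_encoder_dims(obs_dim, action_dim, latent_dim,
                         use_next_obs_in_context, use_information_bottleneck,
                         reward_dim=1):
    # context tuples are (s, a, r) or (s, a, r, s') depending on the flag
    input_dim = obs_dim + action_dim + reward_dim
    if use_next_obs_in_context:
        input_dim += obs_dim
    # with an information bottleneck the encoder outputs a mean and a variance per latent dim
    output_dim = 2 * latent_dim if use_information_bottleneck else latent_dim
    return input_dim, output_dim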
def experiment(variant):
    eval_env = gym.make(variant['env_name'])
    expl_env = eval_env
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size
    M = variant['layer_size']

    # Q and policy networks
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)

    # initialize the policy with behavior cloning or not
    if variant['bc_model'] is None:
        policy = TanhGaussianPolicy(
            obs_dim=obs_dim,
            action_dim=action_dim,
            hidden_sizes=[M, M],
        ).to(ptu.device)
    else:
        bc_model = Mlp(
            input_size=obs_dim,
            output_size=action_dim,
            hidden_sizes=[64, 64],
            output_activation=F.tanh,
        ).to(ptu.device)
        # `map_location` is expected to be defined at module level
        checkpoint = torch.load(variant['bc_model'], map_location=map_location)
        bc_model.load_state_dict(checkpoint['network_state_dict'])
        print('Loading bc model: {}'.format(variant['bc_model']))
        # policy initialized with bc
        policy = TanhGaussianPolicy_BC(
            obs_dim=obs_dim,
            action_dim=action_dim,
            mean_network=bc_model,
            hidden_sizes=[M, M],
        ).to(ptu.device)

    # if bonus: define bonus networks
    if not variant['offline']:
        bonus_layer_size = variant['bonus_layer_size']
        bonus_network = Mlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=[bonus_layer_size, bonus_layer_size],
            output_activation=F.sigmoid,
        ).to(ptu.device)
        checkpoint = torch.load(variant['bonus_path'], map_location=map_location)
        bonus_network.load_state_dict(checkpoint['network_state_dict'])
        print('Loading bonus model: {}'.format(variant['bonus_path']))
        if variant['initialize_Q'] and bonus_layer_size == M:
            target_qf1.load_state_dict(checkpoint['network_state_dict'])
            target_qf2.load_state_dict(checkpoint['network_state_dict'])
            print('Initialize QF1 and QF2 with the bonus model: {}'.format(variant['bonus_path']))
        if variant['initialize_Q'] and bonus_layer_size != M:
            print('Size mismatch between Q and bonus - turning off the initialization')

    # eval_policy = MakeDeterministic(policy)
    eval_path_collector = CustomMDPPathCollector(eval_env)
    expl_path_collector = MdpPathCollector(expl_env, policy)

    buffer_filename = None
    if variant['buffer_filename'] is not None:
        buffer_filename = variant['buffer_filename']

    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    dataset = eval_env.unwrapped.get_dataset()
    load_hdf5(dataset, replay_buffer, max_size=variant['replay_buffer_size'])

    if variant['normalize']:
        obs_mu, obs_std = dataset['observations'].mean(axis=0), dataset['observations'].std(axis=0)
        bonus_norm_param = [obs_mu, obs_std]
    else:
        bonus_norm_param = [None] * 2

    # shift the reward
    if variant['reward_shift'] is not None:
        rewards_shift_param = min(dataset['rewards']) - variant['reward_shift']
        print('.... reward is shifted: {}'.format(rewards_shift_param))
    else:
        rewards_shift_param = None

    if variant['offline']:
        trainer = SACTrainer(
            env=eval_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            rewards_shift_param=rewards_shift_param,
            **variant['trainer_kwargs'],
        )
        print('Agent of type offline SAC created')
    elif variant['bonus'] == 'bonus_add':
        trainer = SAC_BonusTrainer(
            env=eval_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            bonus_network=bonus_network,
            beta=variant['bonus_beta'],
            use_bonus_critic=variant['use_bonus_critic'],
            use_bonus_policy=variant['use_bonus_policy'],
            use_log=variant['use_log'],
            bonus_norm_param=bonus_norm_param,
            rewards_shift_param=rewards_shift_param,
            device=ptu.device,
            **variant['trainer_kwargs'],
        )
        print('Agent of type SAC + additive bonus created')
    elif variant['bonus'] == 'bonus_mlt':
        trainer = SAC_BonusTrainer_Mlt(
            env=eval_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            bonus_network=bonus_network,
            beta=variant['bonus_beta'],
            use_bonus_critic=variant['use_bonus_critic'],
            use_bonus_policy=variant['use_bonus_policy'],
            bonus_norm_param=bonus_norm_param,
            rewards_shift_param=rewards_shift_param,
            device=ptu.device,
            **variant['trainer_kwargs'],
        )
        print('Agent of type SAC + multiplicative bonus created')
    else:
        raise ValueError('Not implemented error')

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        batch_rl=True,
        q_learning_alg=True,
        **variant['algorithm_kwargs'],
    )
    algorithm.to(ptu.device)
    algorithm.train()
def main(
    env_name,
    seed,
    deterministic,
    traj_prior,
    start_ft_after,
    ft_steps,
    avoid_freezing_z,
    lr,
    batch_size,
    avoid_loading_critics,
):
    config = "configs/{}.json".format(env_name)
    variant = default_config
    if config:
        with open(osp.join(config)) as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)

    exp_name = variant['env_name']
    print("Experiment: {}".format(exp_name))

    env = NormalizedBoxEnv(ENVS[exp_name](**variant['env_params']))
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    print("Observation space:")
    print(env.observation_space)
    print(obs_dim)
    print("Action space:")
    print(env.action_space)
    print(action_dim)
    print("-" * 10)

    # instantiate networks
    latent_dim = variant['latent_size']
    reward_dim = 1
    context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim if variant['algo_params']['use_next_obs_in_context'] else obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2 if variant['algo_params']['use_information_bottleneck'] else latent_dim
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    target_qf1 = qf1.copy()
    target_qf2 = qf2.copy()
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy, **variant['algo_params'])

    # deterministic eval
    if deterministic:
        agent = MakeDeterministic(agent)

    # load trained weights (otherwise simulate random policy)
    path_to_exp = "output/{}/pearl_{}".format(env_name, seed - 1)
    print("Based on experiment: {}".format(path_to_exp))
    context_encoder.load_state_dict(torch.load(os.path.join(path_to_exp, 'context_encoder.pth')))
    policy.load_state_dict(torch.load(os.path.join(path_to_exp, 'policy.pth')))
    if not avoid_loading_critics:
        qf1.load_state_dict(torch.load(os.path.join(path_to_exp, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path_to_exp, 'qf2.pth')))
        target_qf1.load_state_dict(torch.load(os.path.join(path_to_exp, 'target_qf1.pth')))
        target_qf2.load_state_dict(torch.load(os.path.join(path_to_exp, 'target_qf2.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        # `device` is expected to be defined at module level (e.g. the torch device set by ptu)
        agent.to(device)
        policy.to(device)
        context_encoder.to(device)
        qf1.to(device)
        qf2.to(device)
        target_qf1.to(device)
        target_qf2.to(device)

    helper = PEARLFineTuningHelper(
        env=env,
        agent=agent,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        num_exp_traj_eval=traj_prior,
        start_fine_tuning=start_ft_after,
        fine_tuning_steps=ft_steps,
        should_freeze_z=(not avoid_freezing_z),
        replay_buffer_size=int(1e6),
        batch_size=batch_size,
        discount=0.99,
        policy_lr=lr,
        qf_lr=lr,
        temp_lr=lr,
        target_entropy=-action_dim,
    )

    helper.fine_tune(variant=variant, seed=seed)
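# --- Illustration (not part of the original script) ---
# main() above is presumably exposed through a CLI elsewhere in the repository; a
# hypothetical direct invocation with placeholder argument values might look like
# the following (it assumes configs/<env_name>.json and pretrained weights under
# output/<env_name>/pearl_<seed-1> already exist).
if __name__ == "__main__":
    main(
        env_name='cheetah-vel',
        seed=1,
        deterministic=True,
        traj_prior=2,
        start_ft_after=0,
        ft_steps=1000,
        avoid_freezing_z=False,
        lr=3e-4,
        batch_size=256,
        avoid_loading_critics=False,
    )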