# Training setup: build the policy, the replay buffer and the MCTS parameters.
# Each statement sits on its own line; when they were collapsed onto a single
# physical line the leading '#' turned all of them into one comment.

# Load environment constants
# In this fragment env_tmp is only queried for its sizes and program library.
env_tmp = RecursiveListEnv(length=5, encoding_dim=conf.encoding_dim)
num_programs = env_tmp.get_num_programs()
num_non_primary_programs = env_tmp.get_num_non_primary_programs()
observation_dim = env_tmp.get_observation_dim()
programs_library = env_tmp.programs_library

# Load Alpha-NPI policy
encoder = RecursiveListEnvEncoder(env_tmp.get_observation_dim(), conf.encoding_dim)
# Indices of the library entries whose 'level' is greater than 0.
indices_non_primary_programs = [
    p['index'] for p in programs_library.values() if p['level'] > 0
]
policy = Policy(
    encoder,
    conf.hidden_size,
    num_programs,
    num_non_primary_programs,
    conf.program_embedding_dim,
    conf.encoding_dim,
    indices_non_primary_programs,
    conf.learning_rate,
)

# Load replay buffer
# Same selection rule as above (level > 0), built as a separate list so the
# buffer and the policy do not share one list object.
idx_tasks = [
    prog['index']
    for prog in env_tmp.programs_library.values()
    if prog['level'] > 0
]
buffer = PrioritizedReplayBuffer(
    conf.buffer_max_length,
    idx_tasks,
    p1=conf.proba_replay_buffer,
)

# Prepare mcts params
# NOTE(review): presumably keyed by program level -> maximum depth; confirm
# against the code that consumes max_depth_dict.
max_depth_dict = {1: 5, 2: 5, 3: 5}
# NOTE(review): the mcts_train_params literal is cut off at this point in the
# visible source. It is kept verbatim as a comment rather than completed by
# guesswork; restore the remaining entries and the closing brace from the
# full file.
# mcts_train_params = {
#     'number_of_simulations': conf.number_of_simulations,
# Validation setup: open the results file, rebuild the policy, load its
# weights, then start the per-length validation loop.
# Each statement sits on its own line; when they were collapsed onto a single
# physical line the first inline '#' commented out everything after open().

# NOTE(review): no matching close() is visible in this fragment; confirm the
# handle is closed once validation finishes.
results_file = open(results_save_path, 'w')

# Load environment constants
# In this fragment env_tmp is only queried for its sizes and program library.
env_tmp = ListEnv(length=5, encoding_dim=conf.encoding_dim)
num_programs = env_tmp.get_num_programs()
num_non_primary_programs = env_tmp.get_num_non_primary_programs()
observation_dim = env_tmp.get_observation_dim()
programs_library = env_tmp.programs_library

# Load Alpha-NPI policy
encoder = ListEnvEncoder(env_tmp.get_observation_dim(), conf.encoding_dim)
# Indices of the library entries whose 'level' is greater than 0.
indices_non_primary_programs = [
    p['index'] for p in programs_library.values() if p['level'] > 0
]
policy = Policy(
    encoder,
    conf.hidden_size,
    num_programs,
    num_non_primary_programs,
    conf.program_embedding_dim,
    conf.encoding_dim,
    indices_non_primary_programs,
    conf.learning_rate,
)
# Overwrite the freshly constructed policy's state with the saved one.
policy.load_state_dict(torch.load(load_path))

# Start validation
if verbose:
    print('Start validation for model: {}'.format(load_path))

if save_results:
    results_file.write('Validation on model: {}'.format(load_path) + ' \n')

# NOTE(review): the loop variable shadows the builtin len(). It is left
# unchanged because the loop body continues beyond this fragment and may
# refer to it by this name; rename it together with the rest of the body.
for len in [5, 10, 20, 60, 100]:
    mcts_rewards_normalized = []
    mcts_rewards = []