def __init__(self, config, args, threshold=0.3):
    self.bs_obj = CREST(threshold=threshold)
    self.config = config
    validation_games = 20
    teacher_path = config['general']['teacher_model_path']

    print('Setting up TextWorld environment...')
    self.batch_size = 1

    # load
    print('Making env id {}'.format(config['general']['env_id']))
    env_id = gym_textworld.make_batch(env_id=config['general']['env_id'],
                                      batch_size=self.batch_size,
                                      parallel=True)
    self.env = gym.make(env_id)
    # self.env.seed(config['general']['random_seed'])

    test_batch_size = config['training']['scheduling']['test_batch_size']
    # valid
    valid_env_name = config['general']['valid_env_id']
    valid_env_id = gym_textworld.make_batch(env_id=valid_env_name,
                                            batch_size=test_batch_size,
                                            parallel=True)
    self.valid_env = gym.make(valid_env_id)
    self.valid_env.seed(config['general']['random_seed'])

    self.teacher_agent = get_agent(config, self.env)
    print('Loading teacher from : ', teacher_path)
    self.teacher_agent.model.load_state_dict(torch.load(teacher_path))
    # import time; time.sleep(5)

    self.hidden_size = config['model']['lstm_dqn']['action_scorer_hidden_dim']
    self.hash_features = {}
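# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not part of the original pipeline): loading
# a teacher checkpoint in a device-safe way. The constructor above calls
# torch.load(teacher_path) directly, which restores tensors onto the device
# they were saved from; mapping to CPU first avoids failures when a GPU-trained
# checkpoint is restored on a CPU-only machine. `model` and `teacher_path`
# mirror the names used above; the helper name is ours.
# ---------------------------------------------------------------------------
def _load_teacher_weights(model, teacher_path):
    import torch

    state_dict = torch.load(teacher_path, map_location='cpu')  # CPU-safe load
    model.load_state_dict(state_dict)
    model.eval()  # teacher is used for inference only (assumption)
    return model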
def train(config, prune=False, embed='cnet'):
    # train env
    print('Setting up TextWorld environment...')
    batch_size = config['training']['scheduling']['batch_size']
    env_id = gym_textworld.make_batch(env_id=config['general']['env_id'],
                                      batch_size=batch_size,
                                      parallel=True)
    env = gym.make(env_id)
    env.seed(config['general']['random_seed'])

    print("##" * 30)
    if prune:
        print('Using state pruning ...')
    else:
        print('Not using state pruning ...')
    print("##" * 30)

    # valid and test env
    run_test = config['general']['run_test']
    if run_test:
        test_batch_size = config['training']['scheduling']['test_batch_size']
        # valid
        valid_env_name = config['general']['valid_env_id']
        valid_env_id = gym_textworld.make_batch(env_id=valid_env_name,
                                                batch_size=test_batch_size,
                                                parallel=True)
        valid_env = gym.make(valid_env_id)
        valid_env.seed(config['general']['random_seed'])
        # valid_env.reset()

        # test
        test_env_name_list = config['general']['test_env_id']
        assert isinstance(test_env_name_list, list)
        test_env_id_list = [gym_textworld.make_batch(env_id=item,
                                                     batch_size=test_batch_size,
                                                     parallel=True)
                            for item in test_env_name_list]
        test_env_list = [gym.make(test_env_id) for test_env_id in test_env_id_list]
        for i in range(len(test_env_list)):
            test_env_list[i].seed(config['general']['random_seed'])
            # test_env_list[i].reset()

    print('Done.')

    # Set the random seed manually for reproducibility.
    np.random.seed(config['general']['random_seed'])
    torch.manual_seed(config['general']['random_seed'])
    if torch.cuda.is_available():
        if not config['general']['use_cuda']:
            logger.warning("WARNING: CUDA device detected but 'use_cuda: false' found in config.yaml")
        else:
            torch.backends.cudnn.deterministic = True
            torch.cuda.manual_seed(config['general']['random_seed'])
    else:
        config['general']['use_cuda'] = False  # Disable CUDA.

    use_cuda = config['general']['use_cuda']
    revisit_counting = config['general']['revisit_counting']
    replay_batch_size = config['general']['replay_batch_size']
    history_size = config['general']['history_size']
    update_from = config['general']['update_from']
    replay_memory_capacity = config['general']['replay_memory_capacity']
    replay_memory_priority_fraction = config['general']['replay_memory_priority_fraction']

    word_vocab = dict2list(env.observation_space.id2w)
    word2id = {}
    for i, w in enumerate(word_vocab):
        word2id[w] = i

    if config['general']['exp_act']:
        print('##' * 30)
        print('Using expanded verb list')
        verb_list = read_file("data/vocabs/trial_run_custom_tw/verb_vocab.txt")
        object_name_list = read_file("data/vocabs/common_nouns.txt")
    else:
        # This option only works for coin collector.
        verb_list = ["go", "take", "unlock", "lock", "drop", "look",
                     "insert", "open", "inventory", "close"]
        object_name_list = ["east", "west", "north", "south", "coin", "apple",
                            "carrot", "textbook", "passkey", "keycard"]

    # Add missing words to word2id
    for w in verb_list:
        if w not in word2id.keys():
            word2id[w] = len(word2id)
            word_vocab += [w, ]
    for w in object_name_list:
        if w not in word2id.keys():
            word2id[w] = len(word2id)
            word_vocab += [w, ]

    verb_map = [word2id[w] for w in verb_list if w in word2id]
    noun_map = [word2id[w] for w in object_name_list if w in word2id]

    # teacher_path = config['general']['teacher_model_path']
    # teacher_agent = Agent(config, word_vocab, verb_map, noun_map,
    #                       att=config['general']['use_attention'],
    #                       bootstrap=False,
    #                       replay_memory_capacity=replay_memory_capacity,
    #                       replay_memory_priority_fraction=replay_memory_priority_fraction)
    # teacher_agent.model.load_state_dict(torch.load(teacher_path))
    # teacher_agent.model.eval()

    student_agent = Agent(config, word_vocab, verb_map, noun_map,
                          att=config['general']['use_attention'],
                          bootstrap=config['general']['student'],
                          replay_memory_capacity=replay_memory_capacity,
                          replay_memory_priority_fraction=replay_memory_priority_fraction,
                          embed=embed)

    init_learning_rate = config['training']['optimizer']['learning_rate']
    exp_dir = get_experiment_dir(config)
    summary = SummaryWriter(exp_dir)

    parameters = filter(lambda p: p.requires_grad, student_agent.model.parameters())
    if config['training']['optimizer']['step_rule'] == 'sgd':
        optimizer = torch.optim.SGD(parameters, lr=init_learning_rate)
    elif config['training']['optimizer']['step_rule'] == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=init_learning_rate)

    log_every = 100
    reward_avg = SlidingAverage('reward avg', steps=log_every)
    step_avg = SlidingAverage('step avg', steps=log_every)
    loss_avg = SlidingAverage('loss avg', steps=log_every)

    # save & reload checkpoint only in 0th agent
    best_avg_reward = -10000
    best_avg_step = 10000

    # step penalty
    discount_gamma = config['general']['discount_gamma']
    provide_prev_action = config['general']['provide_prev_action']

    # epsilon greedy
    epsilon_anneal_epochs = config['general']['epsilon_anneal_epochs']
    epsilon_anneal_from = config['general']['epsilon_anneal_from']
    epsilon_anneal_to = config['general']['epsilon_anneal_to']

    # counting reward
    revisit_counting_lambda_anneal_epochs = config['general']['revisit_counting_lambda_anneal_epochs']
    revisit_counting_lambda_anneal_from = config['general']['revisit_counting_lambda_anneal_from']
    revisit_counting_lambda_anneal_to = config['general']['revisit_counting_lambda_anneal_to']

    model_checkpoint_path = config['training']['scheduling']['model_checkpoint_path']

    epsilon = epsilon_anneal_from
    revisit_counting_lambda = revisit_counting_lambda_anneal_from

    #######################################################################
    #####                   Load the teacher data                    #####
    #######################################################################
    prefix_name = get_prefix(args)
    filename = './data/teacher_data/{}.npz'.format(prefix_name)
    teacher_dict = np.load(filename, allow_pickle=True)
    # import ipdb; ipdb.set_trace()

    global_action_set = set()

    print("##" * 30)
    print("Training for {} epochs".format(config['training']['scheduling']['epoch']))
    print("##" * 30)

    import time
    t0 = time.time()
    for epoch in range(config['training']['scheduling']['epoch']):
        student_agent.model.train()
        obs, infos = env.reset()
        student_agent.reset(infos)

        # this is the string identifier for loading the episodic action distribution
        id_string = student_agent.get_observation_strings(infos)
        cont_flag = False
        for id_ in id_string:
            if id_ not in teacher_dict.keys():
                cont_flag = True
        if cont_flag:
            print('Skipping this epoch ...')
            continue

        # Episodic action list
        action_dist = [teacher_dict[id_string[k]][-1] for k in range(len(id_string))]
        action_dist = [[x for x in item.keys()] for item in action_dist]
        for item in action_dist:
            global_action_set.update(item)

        print_command_string, print_rewards = [[] for _ in infos], [[] for _ in infos]
        print_interm_rewards = [[] for _ in infos]
        print_rc_rewards = [[] for _ in infos]
        dones = [False] * batch_size
        rewards = None
        avg_loss_in_this_game = []

        curr_observation_strings = student_agent.get_observation_strings(infos)
        if revisit_counting:
            student_agent.reset_binarized_counter(batch_size)
            revisit_counting_rewards = student_agent.get_binarized_count(curr_observation_strings)

        current_game_step = 0
        prev_actions = ["" for _ in range(batch_size)] if provide_prev_action else None
        input_description, description_id_list, student_desc, _ = \
            student_agent.get_game_step_info(obs, infos, prev_actions,
                                             prune=prune,
                                             teacher_actions=action_dist,
                                             ret_desc=True)
        curr_ras_hidden, curr_ras_cell = None, None  # ras: recurrent action scorer
        memory_cache = [[] for _ in range(batch_size)]
        solved = [0 for _ in range(batch_size)]

        while not all(dones):
            student_agent.model.train()
            v_idx, n_idx, chosen_strings, curr_ras_hidden, curr_ras_cell = \
                student_agent.generate_one_command(input_description, curr_ras_hidden,
                                                   curr_ras_cell, epsilon=0.0,
                                                   return_att=args.use_attention)
            obs, rewards, dones, infos = env.step(chosen_strings)
            curr_observation_strings = student_agent.get_observation_strings(infos)
            # print(chosen_strings)
            if provide_prev_action:
                prev_actions = chosen_strings

            # counting
            if revisit_counting:
                revisit_counting_rewards = student_agent.get_binarized_count(curr_observation_strings, update=True)
            else:
                revisit_counting_rewards = [0.0 for b in range(batch_size)]
            student_agent.revisit_counting_rewards.append(revisit_counting_rewards)
            revisit_counting_rewards = [float(format(item, ".3f")) for item in revisit_counting_rewards]

            for i in range(len(infos)):
                print_command_string[i].append(chosen_strings[i])
                print_rewards[i].append(rewards[i])
                print_interm_rewards[i].append(infos[i]["intermediate_reward"])
                print_rc_rewards[i].append(revisit_counting_rewards[i])
            if type(dones) is bool:
                dones = [dones] * batch_size
            student_agent.rewards.append(rewards)
            student_agent.dones.append(dones)
            student_agent.intermediate_rewards.append([info["intermediate_reward"] for info in infos])

            # compute rewards, and push into replay memory
            rewards_np, rewards_pt, mask_np, mask_pt, memory_mask = \
                student_agent.compute_reward(revisit_counting_lambda=revisit_counting_lambda,
                                             revisit_counting=revisit_counting)

            ###############################
            #####  Pruned state desc  #####
            ###############################
            curr_description_id_list = description_id_list
            input_description, description_id_list, student_desc, _ = \
                student_agent.get_game_step_info(obs, infos, prev_actions,
                                                 prune=prune,
                                                 teacher_actions=action_dist,
                                                 ret_desc=True)

            for b in range(batch_size):
                if memory_mask[b] == 0:
                    continue
                if dones[b] == 1 and rewards[b] == 0:
                    # last possible step
                    is_final = True
                else:
                    is_final = mask_np[b] == 0
                if rewards[b] > 0.0:
                    solved[b] = 1
                # replay memory
                memory_cache[b].append(
                    (curr_description_id_list[b], v_idx[b], n_idx[b], rewards_pt[b],
                     mask_pt[b], dones[b], is_final, curr_observation_strings[b]))

            if current_game_step > 0 and current_game_step % config["general"]["update_per_k_game_steps"] == 0:
                policy_loss = student_agent.update(replay_batch_size, history_size,
                                                   update_from, discount_gamma=discount_gamma)
                if policy_loss is None:
                    continue
                loss = policy_loss
                # Backpropagate
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(student_agent.model.parameters(),
                                               config['training']['optimizer']['clip_grad_norm'])
                optimizer.step()  # apply gradients
                avg_loss_in_this_game.append(to_np(policy_loss))
            current_game_step += 1

        for i, mc in enumerate(memory_cache):
            for item in mc:
                if replay_memory_priority_fraction == 0.0:
                    # vanilla replay memory
                    student_agent.replay_memory.push(*item)
                else:
                    # prioritized replay memory
                    student_agent.replay_memory.push(solved[i], *item)

        student_agent.finish()
        avg_loss_in_this_game = np.mean(avg_loss_in_this_game)
        reward_avg.add(student_agent.final_rewards.mean())
        step_avg.add(student_agent.step_used_before_done.mean())
        loss_avg.add(avg_loss_in_this_game)

        # annealing
        if epoch < epsilon_anneal_epochs:
            epsilon -= (epsilon_anneal_from - epsilon_anneal_to) / float(epsilon_anneal_epochs)
        if epoch < revisit_counting_lambda_anneal_epochs:
            revisit_counting_lambda -= (revisit_counting_lambda_anneal_from -
                                        revisit_counting_lambda_anneal_to) / float(revisit_counting_lambda_anneal_epochs)

        # Tensorboard logging
        # (1) Log some numbers
        if (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            summary.add_scalar('avg_reward', reward_avg.value, epoch + 1)
            summary.add_scalar('curr_reward', student_agent.final_rewards.mean(), epoch + 1)
            summary.add_scalar('curr_interm_reward', student_agent.final_intermediate_rewards.mean(), epoch + 1)
            summary.add_scalar('curr_counting_reward', student_agent.final_counting_rewards.mean(), epoch + 1)
            summary.add_scalar('avg_step', step_avg.value, epoch + 1)
            summary.add_scalar('curr_step', student_agent.step_used_before_done.mean(), epoch + 1)
            summary.add_scalar('loss_avg', loss_avg.value, epoch + 1)
            summary.add_scalar('curr_loss', avg_loss_in_this_game, epoch + 1)
            t1 = time.time()
            summary.add_scalar('time', t1 - t0, epoch + 1)

        msg = 'E#{:03d}, R={:.3f}/{:.3f}/IR{:.3f}/CR{:.3f}, S={:.3f}/{:.3f}, L={:.3f}/{:.3f}, epsilon={:.4f}, lambda_counting={:.4f}'
        msg = msg.format(epoch,
                         np.mean(reward_avg.value), student_agent.final_rewards.mean(),
                         student_agent.final_intermediate_rewards.mean(),
                         student_agent.final_counting_rewards.mean(),
                         np.mean(step_avg.value), student_agent.step_used_before_done.mean(),
                         np.mean(loss_avg.value), avg_loss_in_this_game,
                         epsilon, revisit_counting_lambda)

        if (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            torch.save(student_agent.model.state_dict(),
                       model_checkpoint_path.replace('.pt', '_train.pt'))
            print("=========================================================")
            for prt_cmd, prt_rew, prt_int_rew, prt_rc_rew in zip(print_command_string, print_rewards,
                                                                 print_interm_rewards, print_rc_rewards):
                print("------------------------------")
                print(prt_cmd)
                print(prt_rew)
                print(prt_int_rew)
                print(prt_rc_rew)
        print(msg)

        # test on a different set of games
        if run_test and epoch % config["training"]["scheduling"]["logging_frequency"] == 0:
            valid_R, valid_IR, valid_S = test(config, valid_env, student_agent, test_batch_size, word2id,
                                              prune=prune,
                                              teacher_actions=[list(global_action_set)] * test_batch_size)
            summary.add_scalar('valid_reward', valid_R, epoch + 1)
            summary.add_scalar('valid_interm_reward', valid_IR, epoch + 1)
            summary.add_scalar('valid_step', valid_S, epoch + 1)

            # save & reload checkpoint by best valid performance
            if valid_R > best_avg_reward or (valid_R == best_avg_reward and valid_S < best_avg_step):
                best_avg_reward = valid_R
                best_avg_step = valid_S
                torch.save(student_agent.model.state_dict(),
                           model_checkpoint_path.replace('.pt', '_best.pt'))
                print("========= saved checkpoint =========")
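# ---------------------------------------------------------------------------
# Illustrative sketch (assumption): the teacher data consumed above is an .npz
# archive of pickled Python objects, which is why np.load(..., allow_pickle=True)
# is required, and the training loop reads teacher_dict[obs_string][-1].keys()
# to recover the episodic action set. The key name and the action-count dict
# below are toy placeholders that only mirror that access pattern; in the real
# data the keys are observation strings produced by get_observation_strings.
# ---------------------------------------------------------------------------
def _teacher_npz_roundtrip(path='/tmp/example_teacher.npz'):
    import numpy as np

    # One entry per observation; the last element holds an action -> count
    # mapping (hypothetical layout mirroring the loop above).
    episode_record = np.array([{'go east': 3, 'take coin': 1}], dtype=object)
    np.savez(path, obs_0=episode_record)

    teacher_dict = np.load(path, allow_pickle=True)
    actions = list(teacher_dict['obs_0'][-1].keys())
    return actions  # ['go east', 'take coin']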
def train(config):
    # train env
    print('Setting up TextWorld environment...')
    batch_size = config['training']['scheduling']['batch_size']
    env_id = gym_textworld.make_batch(env_id=config['general']['env_id'],
                                      batch_size=batch_size,
                                      parallel=True)
    env = gym.make(env_id)
    env.seed(config['general']['random_seed'])

    # valid and test env
    run_test = config['general']['run_test']
    if run_test:
        test_batch_size = config['training']['scheduling']['test_batch_size']
        # valid
        valid_env_name = config['general']['valid_env_id']
        valid_env_id = gym_textworld.make_batch(env_id=valid_env_name,
                                                batch_size=test_batch_size,
                                                parallel=True)
        valid_env = gym.make(valid_env_id)
        valid_env.seed(config['general']['random_seed'])

        # test
        test_env_name_list = config['general']['test_env_id']
        assert isinstance(test_env_name_list, list)
        test_env_id_list = [gym_textworld.make_batch(env_id=item,
                                                     batch_size=test_batch_size,
                                                     parallel=True)
                            for item in test_env_name_list]
        test_env_list = [gym.make(test_env_id) for test_env_id in test_env_id_list]
        for i in range(len(test_env_list)):
            test_env_list[i].seed(config['general']['random_seed'])

    print('Done.')

    # Set the random seed manually for reproducibility.
    np.random.seed(config['general']['random_seed'])
    torch.manual_seed(config['general']['random_seed'])
    if torch.cuda.is_available():
        if not config['general']['use_cuda']:
            logger.warning("WARNING: CUDA device detected but 'use_cuda: false' found in config.yaml")
        else:
            torch.backends.cudnn.deterministic = True
            torch.cuda.manual_seed(config['general']['random_seed'])
    else:
        config['general']['use_cuda'] = False  # Disable CUDA.

    revisit_counting = config['general']['revisit_counting']
    replay_batch_size = config['general']['replay_batch_size']
    replay_memory_capacity = config['general']['replay_memory_capacity']
    replay_memory_priority_fraction = config['general']['replay_memory_priority_fraction']

    word_vocab = dict2list(env.observation_space.id2w)
    word2id = {}
    for i, w in enumerate(word_vocab):
        word2id[w] = i

    # collect all nouns
    verb_list = ["go", "take"]
    object_name_list = ["east", "west", "north", "south", "coin"]
    verb_map = [word2id[w] for w in verb_list if w in word2id]
    noun_map = [word2id[w] for w in object_name_list if w in word2id]

    agent = RLAgent(config, word_vocab, verb_map, noun_map,
                    replay_memory_capacity=replay_memory_capacity,
                    replay_memory_priority_fraction=replay_memory_priority_fraction)

    init_learning_rate = config['training']['optimizer']['learning_rate']
    exp_dir = get_experiment_dir(config)
    summary = SummaryWriter(exp_dir)

    parameters = filter(lambda p: p.requires_grad, agent.model.parameters())
    if config['training']['optimizer']['step_rule'] == 'sgd':
        optimizer = torch.optim.SGD(parameters, lr=init_learning_rate)
    elif config['training']['optimizer']['step_rule'] == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=init_learning_rate)

    log_every = 100
    reward_avg = SlidingAverage('reward avg', steps=log_every)
    step_avg = SlidingAverage('step avg', steps=log_every)
    loss_avg = SlidingAverage('loss avg', steps=log_every)

    # save & reload checkpoint only in 0th agent
    best_avg_reward = -10000
    best_avg_step = 10000

    # step penalty
    discount_gamma = config['general']['discount_gamma']
    provide_prev_action = config['general']['provide_prev_action']

    # epsilon greedy
    epsilon_anneal_epochs = config['general']['epsilon_anneal_epochs']
    epsilon_anneal_from = config['general']['epsilon_anneal_from']
    epsilon_anneal_to = config['general']['epsilon_anneal_to']

    # counting reward
    revisit_counting_lambda_anneal_epochs = config['general']['revisit_counting_lambda_anneal_epochs']
    revisit_counting_lambda_anneal_from = config['general']['revisit_counting_lambda_anneal_from']
    revisit_counting_lambda_anneal_to = config['general']['revisit_counting_lambda_anneal_to']

    epsilon = epsilon_anneal_from
    revisit_counting_lambda = revisit_counting_lambda_anneal_from

    for epoch in range(config['training']['scheduling']['epoch']):
        agent.model.train()
        obs, infos = env.reset()
        agent.reset(infos)

        print_command_string, print_rewards = [[] for _ in infos], [[] for _ in infos]
        print_interm_rewards = [[] for _ in infos]
        print_rc_rewards = [[] for _ in infos]
        dones = [False] * batch_size
        rewards = None
        avg_loss_in_this_game = []

        new_observation_strings = agent.get_observation_strings(infos)
        if revisit_counting:
            agent.reset_binarized_counter(batch_size)
            revisit_counting_rewards = agent.get_binarized_count(new_observation_strings)

        current_game_step = 0
        prev_actions = ["" for _ in range(batch_size)] if provide_prev_action else None
        input_description, description_id_list = agent.get_game_step_info(obs, infos, prev_actions)

        while not all(dones):
            v_idx, n_idx, chosen_strings, state_representation = agent.generate_one_command(
                input_description, epsilon=epsilon)
            obs, rewards, dones, infos = env.step(chosen_strings)
            new_observation_strings = agent.get_observation_strings(infos)
            if provide_prev_action:
                prev_actions = chosen_strings

            # counting
            if revisit_counting:
                revisit_counting_rewards = agent.get_binarized_count(new_observation_strings, update=True)
            else:
                revisit_counting_rewards = [0.0 for _ in range(batch_size)]
            agent.revisit_counting_rewards.append(revisit_counting_rewards)
            revisit_counting_rewards = [float(format(item, ".3f")) for item in revisit_counting_rewards]

            for i in range(len(infos)):
                print_command_string[i].append(chosen_strings[i])
                print_rewards[i].append(rewards[i])
                print_interm_rewards[i].append(infos[i]["intermediate_reward"])
                print_rc_rewards[i].append(revisit_counting_rewards[i])
            if type(dones) is bool:
                dones = [dones] * batch_size
            agent.rewards.append(rewards)
            agent.dones.append(dones)
            agent.intermediate_rewards.append([info["intermediate_reward"] for info in infos])

            # compute rewards, and push into replay memory
            rewards_np, rewards, mask_np, mask = agent.compute_reward(
                revisit_counting_lambda=revisit_counting_lambda,
                revisit_counting=revisit_counting)

            curr_description_id_list = description_id_list
            input_description, description_id_list = agent.get_game_step_info(obs, infos, prev_actions)

            for b in range(batch_size):
                if mask_np[b] == 0:
                    continue
                if replay_memory_priority_fraction == 0.0:
                    # vanilla replay memory
                    agent.replay_memory.push(curr_description_id_list[b], v_idx[b], n_idx[b],
                                             rewards[b], mask[b], dones[b],
                                             description_id_list[b], new_observation_strings[b])
                else:
                    # prioritized replay memory
                    is_prior = rewards_np[b] > 0.0
                    agent.replay_memory.push(is_prior, curr_description_id_list[b], v_idx[b], n_idx[b],
                                             rewards[b], mask[b], dones[b],
                                             description_id_list[b], new_observation_strings[b])

            if current_game_step > 0 and current_game_step % config["general"]["update_per_k_game_steps"] == 0:
                policy_loss = agent.update(replay_batch_size, discount_gamma=discount_gamma)
                if policy_loss is None:
                    continue
                loss = policy_loss
                # Backpropagate
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm(agent.model.parameters(),
                                              config['training']['optimizer']['clip_grad_norm'])
                optimizer.step()  # apply gradients
                avg_loss_in_this_game.append(to_np(policy_loss))
            current_game_step += 1

        agent.finish()
        avg_loss_in_this_game = np.mean(avg_loss_in_this_game)
        reward_avg.add(agent.final_rewards.mean())
        step_avg.add(agent.step_used_before_done.mean())
        loss_avg.add(avg_loss_in_this_game)

        # annealing
        if epoch < epsilon_anneal_epochs:
            epsilon -= (epsilon_anneal_from - epsilon_anneal_to) / float(epsilon_anneal_epochs)
        if epoch < revisit_counting_lambda_anneal_epochs:
            revisit_counting_lambda -= (revisit_counting_lambda_anneal_from -
                                        revisit_counting_lambda_anneal_to) / float(revisit_counting_lambda_anneal_epochs)

        # Tensorboard logging
        # (1) Log some numbers
        if (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            summary.add_scalar('avg_reward', reward_avg.value, epoch + 1)
            summary.add_scalar('curr_reward', agent.final_rewards.mean(), epoch + 1)
            summary.add_scalar('curr_interm_reward', agent.final_intermediate_rewards.mean(), epoch + 1)
            summary.add_scalar('curr_counting_reward', agent.final_counting_rewards.mean(), epoch + 1)
            summary.add_scalar('avg_step', step_avg.value, epoch + 1)
            summary.add_scalar('curr_step', agent.step_used_before_done.mean(), epoch + 1)
            summary.add_scalar('loss_avg', loss_avg.value, epoch + 1)
            summary.add_scalar('curr_loss', avg_loss_in_this_game, epoch + 1)

        msg = 'E#{:03d}, R={:.3f}/{:.3f}/IR{:.3f}/CR{:.3f}, S={:.3f}/{:.3f}, L={:.3f}/{:.3f}, epsilon={:.4f}, lambda_counting={:.4f}'
        msg = msg.format(epoch,
                         np.mean(reward_avg.value), agent.final_rewards.mean(),
                         agent.final_intermediate_rewards.mean(), agent.final_counting_rewards.mean(),
                         np.mean(step_avg.value), agent.step_used_before_done.mean(),
                         np.mean(loss_avg.value), avg_loss_in_this_game,
                         epsilon, revisit_counting_lambda)

        if (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            print("=========================================================")
            for prt_cmd, prt_rew, prt_int_rew, prt_rc_rew in zip(print_command_string, print_rewards,
                                                                 print_interm_rewards, print_rc_rewards):
                print("------------------------------")
                print(prt_cmd)
                print(prt_rew)
                print(prt_int_rew)
                print(prt_rc_rew)
        print(msg)

        # test on a different set of games
        if run_test and (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            valid_R, valid_IR, valid_S = test(config, valid_env, agent, test_batch_size, word2id)
            summary.add_scalar('valid_reward', valid_R, epoch + 1)
            summary.add_scalar('valid_interm_reward', valid_IR, epoch + 1)
            summary.add_scalar('valid_step', valid_S, epoch + 1)

            # save & reload checkpoint by best valid performance
            model_checkpoint_path = config['training']['scheduling']['model_checkpoint_path']
            if valid_R > best_avg_reward or (valid_R == best_avg_reward and valid_S < best_avg_step):
                best_avg_reward = valid_R
                best_avg_step = valid_S
                torch.save(agent.model.state_dict(), model_checkpoint_path)
                print("========= saved checkpoint =========")

            for test_id in range(len(test_env_list)):
                R, IR, S = test(config, test_env_list[test_id], agent, test_batch_size, word2id)
                summary.add_scalar('test_reward_' + str(test_id), R, epoch + 1)
                summary.add_scalar('test_interm_reward_' + str(test_id), IR, epoch + 1)
                summary.add_scalar('test_step_' + str(test_id), S, epoch + 1)
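# ---------------------------------------------------------------------------
# Illustrative sketch (assumption): both training loops above decay epsilon and
# the revisit-counting lambda linearly, subtracting a fixed per-epoch step of
# (value_from - value_to) / anneal_epochs until anneal_epochs is reached. The
# helper below reproduces that schedule in closed form; the name
# `linear_anneal` is ours and is not used elsewhere in the repository.
# ---------------------------------------------------------------------------
def linear_anneal(epoch, value_from, value_to, anneal_epochs):
    """Value of a linearly annealed quantity at the start of `epoch`."""
    if epoch >= anneal_epochs:
        return value_to
    step = (value_from - value_to) / float(anneal_epochs)
    return value_from - step * epoch

# Example with toy numbers (epsilon annealed from 1.0 to 0.2 over 1000 epochs):
#   linear_anneal(0, 1.0, 0.2, 1000)    -> 1.0
#   linear_anneal(500, 1.0, 0.2, 1000)  -> ~0.6
#   linear_anneal(1000, 1.0, 0.2, 1000) -> 0.2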
def load_valid_env(self, valid_env_name):
    test_batch_size = 1
    valid_env_id = gym_textworld.make_batch(env_id=valid_env_name,
                                            batch_size=test_batch_size,
                                            parallel=True)
    self.valid_env = gym.make(valid_env_id)
    self.valid_env.seed(self.config['general']['random_seed'])
    print('Loaded env name: ', valid_env_name)
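# ---------------------------------------------------------------------------
# Illustrative sketch (assumption): evaluating one checkpoint against several
# validation sets by re-pointing the wrapper at each environment in turn.
# `evaluator` stands for an instance of the class that owns load_valid_env,
# and `evaluator.evaluate()` is a hypothetical evaluation call; only
# load_valid_env itself comes from the code above.
# ---------------------------------------------------------------------------
def _evaluate_on_many(evaluator, valid_env_names):
    results = {}
    for name in valid_env_names:
        evaluator.load_valid_env(name)        # rebuilds self.valid_env (see above)
        results[name] = evaluator.evaluate()  # hypothetical evaluation method
    return results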