Example #1
def get_chosen_strings(self, v_idx, n_idx):
    # Map predicted verb/noun indices back to vocabulary words and join them
    # into "verb noun" command strings, one per batch item.
    v_idx_np = to_np(v_idx)
    n_idx_np = to_np(n_idx)
    res_str = []
    for i in range(n_idx_np.shape[0]):
        v, n = self.verb_map[v_idx_np[i]], self.noun_map[n_idx_np[i]]
        res_str.append(self.word_vocab[v] + " " + self.word_vocab[n])
    return res_str
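Both examples lean on `to_np` / `to_pt` conversion helpers defined elsewhere in the repository. A minimal sketch of what such helpers typically look like, written here as an assumption rather than the project's exact implementation:

import numpy as np
import torch


def to_np(x):
    # Assumed helper: move a torch tensor to CPU and return it as a NumPy array.
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def to_pt(x, use_cuda=False):
    # Assumed helper: wrap a NumPy array (or list) as a torch LongTensor, optionally on GPU.
    t = torch.from_numpy(np.asarray(x)).long()
    return t.cuda() if use_cuda else t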
Example #2
def choose_maxQ_command(self, verb_rank, noun_rank):
    # Greedy action selection: pick the verb and noun with the highest Q-value
    # for each item in the batch.
    batch_size = verb_rank.size(0)
    vr, nr = to_np(verb_rank), to_np(noun_rank)
    v_idx = np.argmax(vr, -1)
    n_idx = np.argmax(nr, -1)
    v_qvalue, n_qvalue = [], []
    for i in range(batch_size):
        v_qvalue.append(verb_rank[i][v_idx[i]])
        n_qvalue.append(noun_rank[i][n_idx[i]])
    # Each entry is a 0-dim tensor, so stack (not cat) them into a 1-D batch of Q-values.
    v_qvalue, n_qvalue = torch.stack(v_qvalue), torch.stack(n_qvalue)
    v_idx, n_idx = to_pt(v_idx, self.use_cuda), to_pt(n_idx, self.use_cuda)
    return v_qvalue, v_idx, n_qvalue, n_idx
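The per-sample loop above can be collapsed into a single call to `torch.max`, which returns both the maximum Q-values and their argmax indices along a dimension. A vectorized sketch of the same greedy selection (not the repository's code):

def choose_maxQ_command_vectorized(self, verb_rank, noun_rank):
    # Sketch: torch.max over the last dimension yields (Q-values, indices) per batch item.
    v_qvalue, v_idx = torch.max(verb_rank, dim=-1)
    n_qvalue, n_idx = torch.max(noun_rank, dim=-1)
    return v_qvalue, v_idx, n_qvalue, n_idx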
def choose_random_command(self, verb_rank, noun_rank):
    # Exploration: sample a uniformly random verb and noun index for each batch item,
    # and return the Q-values of the sampled actions.
    batch_size = verb_rank.size(0)
    vr, nr = to_np(verb_rank), to_np(noun_rank)

    v_idx, n_idx = [], []
    for i in range(batch_size):
        v_idx.append(np.random.choice(len(vr[i]), 1)[0])
        n_idx.append(np.random.choice(len(nr[i]), 1)[0])
    v_qvalue, n_qvalue = [], []
    for i in range(batch_size):
        v_qvalue.append(verb_rank[i][v_idx[i]])
        n_qvalue.append(noun_rank[i][n_idx[i]])
    v_qvalue, n_qvalue = torch.stack(v_qvalue), torch.stack(n_qvalue)
    v_idx, n_idx = to_pt(np.array(v_idx), self.use_cuda), to_pt(np.array(n_idx), self.use_cuda)
    return v_qvalue, v_idx, n_qvalue, n_idx
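In `train` below, action selection happens inside `agent.generate_one_command(...)`, which is not shown in this excerpt. A rough, assumed sketch of how an epsilon-greedy policy could combine the two selection methods above (the real method presumably also runs the model forward pass to produce `verb_rank`, `noun_rank`, and the state representation it returns):

def epsilon_greedy_command(self, verb_rank, noun_rank, epsilon):
    # Sketch: explore with probability epsilon, otherwise act greedily on Q-values.
    if np.random.rand() < epsilon:
        _, v_idx, _, n_idx = self.choose_random_command(verb_rank, noun_rank)
    else:
        _, v_idx, _, n_idx = self.choose_maxQ_command(verb_rank, noun_rank)
    chosen_strings = self.get_chosen_strings(v_idx, n_idx)
    return v_idx, n_idx, chosen_strings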
def train(config):
    # train env
    print('Setting up TextWorld environment...')
    batch_size = config['training']['scheduling']['batch_size']
    env_id = gym_textworld.make_batch(env_id=config['general']['env_id'],
                                      batch_size=batch_size,
                                      parallel=True)
    env = gym.make(env_id)
    env.seed(config['general']['random_seed'])

    # valid and test env
    run_test = config['general']['run_test']
    if run_test:
        test_batch_size = config['training']['scheduling']['test_batch_size']
        # valid
        valid_env_name = config['general']['valid_env_id']

        valid_env_id = gym_textworld.make_batch(env_id=valid_env_name,
                                                batch_size=test_batch_size,
                                                parallel=True)
        valid_env = gym.make(valid_env_id)
        valid_env.seed(config['general']['random_seed'])

        # test
        test_env_name_list = config['general']['test_env_id']
        assert isinstance(test_env_name_list, list)

        test_env_id_list = [
            gym_textworld.make_batch(env_id=item,
                                     batch_size=test_batch_size,
                                     parallel=True)
            for item in test_env_name_list
        ]
        test_env_list = [
            gym.make(test_env_id) for test_env_id in test_env_id_list
        ]
        for i in range(len(test_env_list)):
            test_env_list[i].seed(config['general']['random_seed'])
    print('Done.')

    # Set the random seed manually for reproducibility.
    np.random.seed(config['general']['random_seed'])
    torch.manual_seed(config['general']['random_seed'])
    if torch.cuda.is_available():
        if not config['general']['use_cuda']:
            logger.warning(
                "WARNING: CUDA device detected but 'use_cuda: false' found in config.yaml"
            )
        else:
            torch.backends.cudnn.deterministic = True
            torch.cuda.manual_seed(config['general']['random_seed'])
    else:
        config['general']['use_cuda'] = False  # Disable CUDA.
    revisit_counting = config['general']['revisit_counting']
    replay_batch_size = config['general']['replay_batch_size']
    replay_memory_capacity = config['general']['replay_memory_capacity']
    replay_memory_priority_fraction = config['general'][
        'replay_memory_priority_fraction']

    word_vocab = dict2list(env.observation_space.id2w)
    word2id = {}
    for i, w in enumerate(word_vocab):
        word2id[w] = i

    # collect all nouns
    verb_list = ["go", "take"]
    object_name_list = ["east", "west", "north", "south", "coin"]
    verb_map = [word2id[w] for w in verb_list if w in word2id]
    noun_map = [word2id[w] for w in object_name_list if w in word2id]
    agent = RLAgent(
        config,
        word_vocab,
        verb_map,
        noun_map,
        replay_memory_capacity=replay_memory_capacity,
        replay_memory_priority_fraction=replay_memory_priority_fraction)

    init_learning_rate = config['training']['optimizer']['learning_rate']

    exp_dir = get_experiment_dir(config)
    summary = SummaryWriter(exp_dir)

    parameters = filter(lambda p: p.requires_grad, agent.model.parameters())
    if config['training']['optimizer']['step_rule'] == 'sgd':
        optimizer = torch.optim.SGD(parameters, lr=init_learning_rate)
    elif config['training']['optimizer']['step_rule'] == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=init_learning_rate)

    log_every = 100
    reward_avg = SlidingAverage('reward avg', steps=log_every)
    step_avg = SlidingAverage('step avg', steps=log_every)
    loss_avg = SlidingAverage('loss avg', steps=log_every)

    # save & reload checkpoint only in 0th agent
    best_avg_reward = -10000
    best_avg_step = 10000

    # step penalty
    discount_gamma = config['general']['discount_gamma']
    provide_prev_action = config['general']['provide_prev_action']

    # epsilon greedy
    epsilon_anneal_epochs = config['general']['epsilon_anneal_epochs']
    epsilon_anneal_from = config['general']['epsilon_anneal_from']
    epsilon_anneal_to = config['general']['epsilon_anneal_to']

    # counting reward
    revisit_counting_lambda_anneal_epochs = config['general'][
        'revisit_counting_lambda_anneal_epochs']
    revisit_counting_lambda_anneal_from = config['general'][
        'revisit_counting_lambda_anneal_from']
    revisit_counting_lambda_anneal_to = config['general'][
        'revisit_counting_lambda_anneal_to']

    epsilon = epsilon_anneal_from
    revisit_counting_lambda = revisit_counting_lambda_anneal_from
    for epoch in range(config['training']['scheduling']['epoch']):

        agent.model.train()
        obs, infos = env.reset()
        agent.reset(infos)
        print_command_string, print_rewards = [[] for _ in infos], [[] for _ in infos]
        print_interm_rewards = [[] for _ in infos]
        print_rc_rewards = [[] for _ in infos]

        dones = [False] * batch_size
        rewards = None
        avg_loss_in_this_game = []

        new_observation_strings = agent.get_observation_strings(infos)
        if revisit_counting:
            agent.reset_binarized_counter(batch_size)
            revisit_counting_rewards = agent.get_binarized_count(
                new_observation_strings)

        current_game_step = 0
        prev_actions = ["" for _ in range(batch_size)
                        ] if provide_prev_action else None
        input_description, description_id_list = agent.get_game_step_info(
            obs, infos, prev_actions)

        while not all(dones):

            v_idx, n_idx, chosen_strings, state_representation = agent.generate_one_command(
                input_description, epsilon=epsilon)
            obs, rewards, dones, infos = env.step(chosen_strings)
            new_observation_strings = agent.get_observation_strings(infos)
            if provide_prev_action:
                prev_actions = chosen_strings
            # counting
            if revisit_counting:
                revisit_counting_rewards = agent.get_binarized_count(
                    new_observation_strings, update=True)
            else:
                revisit_counting_rewards = [0.0 for _ in range(batch_size)]
            agent.revisit_counting_rewards.append(revisit_counting_rewards)
            revisit_counting_rewards = [
                float(format(item, ".3f")) for item in revisit_counting_rewards
            ]

            for i in range(len(infos)):
                print_command_string[i].append(chosen_strings[i])
                print_rewards[i].append(rewards[i])
                print_interm_rewards[i].append(infos[i]["intermediate_reward"])
                print_rc_rewards[i].append(revisit_counting_rewards[i])
            if type(dones) is bool:
                dones = [dones] * batch_size
            agent.rewards.append(rewards)
            agent.dones.append(dones)
            agent.intermediate_rewards.append(
                [info["intermediate_reward"] for info in infos])
            # compute rewards, and push into replay memory
            rewards_np, rewards, mask_np, mask = agent.compute_reward(
                revisit_counting_lambda=revisit_counting_lambda,
                revisit_counting=revisit_counting)

            curr_description_id_list = description_id_list
            input_description, description_id_list = agent.get_game_step_info(
                obs, infos, prev_actions)

            for b in range(batch_size):
                if mask_np[b] == 0:
                    continue
                if replay_memory_priority_fraction == 0.0:
                    # vanilla replay memory
                    agent.replay_memory.push(curr_description_id_list[b],
                                             v_idx[b], n_idx[b], rewards[b],
                                             mask[b], dones[b],
                                             description_id_list[b],
                                             new_observation_strings[b])
                else:
                    # prioritized replay memory
                    is_prior = rewards_np[b] > 0.0
                    agent.replay_memory.push(is_prior,
                                             curr_description_id_list[b],
                                             v_idx[b], n_idx[b], rewards[b],
                                             mask[b], dones[b],
                                             description_id_list[b],
                                             new_observation_strings[b])

            if current_game_step > 0 and current_game_step % config["general"][
                    "update_per_k_game_steps"] == 0:
                policy_loss = agent.update(replay_batch_size,
                                           discount_gamma=discount_gamma)
                if policy_loss is None:
                    continue
                loss = policy_loss
                # Backpropagate
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(
                    agent.model.parameters(),
                    config['training']['optimizer']['clip_grad_norm'])
                optimizer.step()  # apply gradients
                avg_loss_in_this_game.append(to_np(policy_loss))
            current_game_step += 1

        agent.finish()
        avg_loss_in_this_game = np.mean(avg_loss_in_this_game)
        reward_avg.add(agent.final_rewards.mean())
        step_avg.add(agent.step_used_before_done.mean())
        loss_avg.add(avg_loss_in_this_game)
        # annealing
        if epoch < epsilon_anneal_epochs:
            epsilon -= (epsilon_anneal_from -
                        epsilon_anneal_to) / float(epsilon_anneal_epochs)
        if epoch < revisit_counting_lambda_anneal_epochs:
            revisit_counting_lambda -= (
                revisit_counting_lambda_anneal_from -
                revisit_counting_lambda_anneal_to
            ) / float(revisit_counting_lambda_anneal_epochs)

        # Tensorboard logging #
        # (1) Log some numbers
        if (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            summary.add_scalar('avg_reward', reward_avg.value, epoch + 1)
            summary.add_scalar('curr_reward', agent.final_rewards.mean(),
                               epoch + 1)
            summary.add_scalar('curr_interm_reward',
                               agent.final_intermediate_rewards.mean(),
                               epoch + 1)
            summary.add_scalar('curr_counting_reward',
                               agent.final_counting_rewards.mean(), epoch + 1)
            summary.add_scalar('avg_step', step_avg.value, epoch + 1)
            summary.add_scalar('curr_step', agent.step_used_before_done.mean(),
                               epoch + 1)
            summary.add_scalar('loss_avg', loss_avg.value, epoch + 1)
            summary.add_scalar('curr_loss', avg_loss_in_this_game, epoch + 1)

        msg = 'E#{:03d}, R={:.3f}/{:.3f}/IR{:.3f}/CR{:.3f}, S={:.3f}/{:.3f}, L={:.3f}/{:.3f}, epsilon={:.4f}, lambda_counting={:.4f}'
        msg = msg.format(epoch, np.mean(reward_avg.value),
                         agent.final_rewards.mean(),
                         agent.final_intermediate_rewards.mean(),
                         agent.final_counting_rewards.mean(),
                         np.mean(step_avg.value),
                         agent.step_used_before_done.mean(),
                         np.mean(loss_avg.value), avg_loss_in_this_game,
                         epsilon, revisit_counting_lambda)
        if (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            print("=========================================================")
            for prt_cmd, prt_rew, prt_int_rew, prt_rc_rew in zip(
                    print_command_string, print_rewards, print_interm_rewards,
                    print_rc_rewards):
                print("------------------------------")
                print(prt_cmd)
                print(prt_rew)
                print(prt_int_rew)
                print(prt_rc_rew)
        print(msg)
        # test on a different set of games
        if run_test and (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            valid_R, valid_IR, valid_S = test(config, valid_env, agent,
                                              test_batch_size, word2id)
            summary.add_scalar('valid_reward', valid_R, epoch + 1)
            summary.add_scalar('valid_interm_reward', valid_IR, epoch + 1)
            summary.add_scalar('valid_step', valid_S, epoch + 1)

            # save & reload checkpoint by best valid performance
            model_checkpoint_path = config['training']['scheduling'][
                'model_checkpoint_path']
            if valid_R > best_avg_reward or (valid_R == best_avg_reward
                                             and valid_S < best_avg_step):
                best_avg_reward = valid_R
                best_avg_step = valid_S
                torch.save(agent.model.state_dict(), model_checkpoint_path)
                print("========= saved checkpoint =========")
                for test_id in range(len(test_env_list)):
                    R, IR, S = test(config, test_env_list[test_id], agent,
                                    test_batch_size, word2id)
                    summary.add_scalar('test_reward_' + str(test_id), R,
                                       epoch + 1)
                    summary.add_scalar('test_interm_reward_' + str(test_id),
                                       IR, epoch + 1)
                    summary.add_scalar('test_step_' + str(test_id), S,
                                       epoch + 1)