Example #1
0
    def admissible_commands_recurrent_teacher_force(self, seq_observation_strings, seq_task_desc_strings, seq_action_candidate_list, seq_target_indices, contains_first_step=False):
        """Do one teacher-forced optimizer step over a sequence of game steps.

        The first ``self.dagger_replay_sample_history_length`` entries of the
        ``seq_*`` arguments are processed in order. The value returned by
        ``self.action_scoring`` as its third output (the "dynamics") is fed
        back into the next step's call.

        Unless ``contains_first_step`` is true, steps whose index is below
        ``self.dagger_replay_sample_update_from`` contribute no loss; their
        dynamics are detached so no gradient flows back through them.

        Only ``seq_task_desc_strings[0]`` is encoded; that encoding is reused
        for every step.

        Returns:
            ``to_np`` of the mean cross-entropy loss over the contributing
            steps, computed before the optimizer step is applied.
        """
        # The task description is encoded once and shared by all steps.
        task_repr, task_mask = self.encode(seq_task_desc_strings[0], use_model="online")
        skip_early_steps = not contains_first_step
        dynamics = None
        step_losses = []

        for t in range(self.dagger_replay_sample_history_length):
            targets = to_pt(np.array(seq_target_indices[t]), enable_cuda=self.use_cuda, type='long')
            obs_repr, obs_mask = self.encode(seq_observation_strings[t], use_model="online")
            scores, _, dynamics = self.action_scoring(seq_action_candidate_list[t], obs_repr, obs_mask, task_repr, task_mask,
                                                      dynamics, use_model="online")
            if skip_early_steps and t < self.dagger_replay_sample_update_from:
                # Early step: only carry the state forward, cut off from the graph.
                dynamics = dynamics.detach()
            else:
                # softmax and cross-entropy
                step_losses.append(self.cross_entropy_loss(scores, targets))

        mean_loss = torch.stack(step_losses).mean()

        # Backpropagate
        self.online_net.zero_grad()
        self.optimizer.zero_grad()
        mean_loss.backward()
        # Gradient-norm clipping helps prevent exploding gradients in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), self.clip_grad_norm)
        self.optimizer.step()  # apply gradients
        return to_np(mean_loss)
Example #2
0
    def command_generation_greedy_generation(self, observation_feats,
                                             task_desc_strings,
                                             previous_dynamics):
        """Greedily decode one command string per batch item.

        The aggregated observation features are projected with
        ``online_net.vision_fc`` and concatenated (along dim 1) with the
        masked mean of the encoded task description. When ``self.recurrent``
        is set, the masked mean of that concatenation is passed through
        ``online_net.rnncell`` (together with ``previous_dynamics`` when it is
        not ``None``).

        Decoding starts from ``[CLS]`` and appends the arg-max token of the
        last position at each step, for at most ``self.max_target_length``
        steps or until every item has produced ``[SEP]``.

        Returns:
            A tuple ``(commands, current_dynamics)`` where ``commands`` is a
            list of decoded strings with ``[CLS]``/``[SEP]`` removed, and
            ``current_dynamics`` is ``None`` when the agent is not recurrent.
        """
        with torch.no_grad():
            batch_size = len(observation_feats)

            pooled_feats = self.aggregate_feats_seq(observation_feats)
            h_obs = self.online_net.vision_fc(pooled_feats)
            h_td, td_mask = self.encode(task_desc_strings, use_model="online")
            h_td_mean = self.online_net.masked_mean(h_td, td_mask).unsqueeze(1)
            device = h_td_mean.device
            h_obs = h_obs.to(device)
            # batch x k boxes x hid
            vision_td = torch.cat((h_obs, h_td_mean), dim=1)
            mask_width = h_obs.shape[1] + h_td_mean.shape[1]
            vision_td_mask = torch.ones((batch_size, mask_width)).to(device)

            current_dynamics = None
            if self.recurrent:
                pooled_vision_td = self.online_net.masked_mean(
                    vision_td, vision_td_mask)
                if previous_dynamics is None:
                    current_dynamics = self.online_net.rnncell(
                        pooled_vision_td)
                else:
                    current_dynamics = self.online_net.rnncell(
                        pooled_vision_td, previous_dynamics)

            # greedy generation
            cls_id = self.word2id["[CLS]"]
            sep_id = self.word2id["[SEP]"]
            generated = [[cls_id] for _ in range(batch_size)]
            finished = np.zeros(batch_size)
            for _ in range(self.max_target_length):
                snapshot = copy.deepcopy(generated)
                padded = pad_sequences(
                    snapshot, maxlen=max_len(snapshot)).astype('int32')
                decoder_input = to_pt(padded, self.use_cuda)
                # mask of ground truth should be the same
                decoder_mask = compute_mask(decoder_input)
                # batch x target_length x vocab
                logits = self.online_net.vision_decode(
                    decoder_input, decoder_mask, vision_td, vision_td_mask,
                    current_dynamics)
                # pointer softmax; keep only the last position, then arg-max
                next_ids = np.argmax(to_np(logits[:, -1]), -1)  # batch
                for b in range(batch_size):
                    if finished[b] == 0:
                        generated[b] = generated[b] + [next_ids[b]]
                    if next_ids[b] == sep_id:
                        finished[b] = 1
                if np.sum(finished) == batch_size:
                    break

            commands = []
            for token_ids in generated:
                text = self.tokenizer.decode(token_ids)
                text = text.replace("[CLS]", "").replace("[SEP]", "").strip()
                commands.append(text.replace(" in / on ", " in/on "))
            return commands, current_dynamics
Example #3
0
    def admissible_commands_teacher_force(self, observation_strings, task_desc_strings, action_candidate_list, target_indices):
        """Do one teacher-forced optimizer step on a single batch of steps.

        Scores ``action_candidate_list`` with the online network (no previous
        dynamics are passed), computes ``self.cross_entropy_loss`` against
        ``target_indices``, back-propagates with gradient-norm clipping and
        applies one optimizer step.

        Returns:
            ``to_np`` of the loss computed before the optimizer step.
        """
        targets = to_pt(np.array(target_indices), enable_cuda=self.use_cuda, type='long')

        obs_repr, obs_mask = self.encode(observation_strings, use_model="online")
        task_repr, task_mask = self.encode(task_desc_strings, use_model="online")

        # ``None`` stands in for the previous dynamics argument.
        scores, _, _ = self.action_scoring(
            action_candidate_list,
            obs_repr, obs_mask,
            task_repr, task_mask,
            None,
            use_model="online",
        )

        # softmax and cross-entropy
        ce_loss = self.cross_entropy_loss(scores, targets)

        # Backpropagate
        self.online_net.zero_grad()
        self.optimizer.zero_grad()
        ce_loss.backward()
        # Gradient-norm clipping helps prevent exploding gradients in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), self.clip_grad_norm)
        self.optimizer.step()  # apply gradients
        return to_np(ce_loss)
Example #4
0
 def get_word_input_from_ids(self, word_id_list):
     """Pad the given word-id lists and convert the result with ``to_pt``.

     The padded length is ``max_len(word_id_list) + 3``.
     """
     # 3 --> see layer.DepthwiseSeparableConv.padding
     padded_length = max_len(word_id_list) + 3
     padded_ids = pad_sequences(word_id_list, maxlen=padded_length, dtype='int32')
     return to_pt(padded_ids, self.use_cuda)
Example #5
0
def train():
    """Train a ``TextDQNAgent`` and periodically report, evaluate and save.

    The configuration comes from ``generic.load_config()``. Environment
    classes are looked up by name in the ``environment`` module. Episodes are
    played in batches until ``episode_no`` exceeds ``agent.max_episode``.

    For every batch of games the per-step transitions are cached and, once
    the games are over, pushed into ``agent.dqn_memory``. DQN updates happen
    inside the step loop every ``agent.update_per_k_game_steps`` total steps
    once ``agent.learn_start_from_this_episode`` has been reached.

    At reporting points the function prints running statistics, optionally
    runs ``evaluate_dqn`` on the evaluation environments, saves checkpoints
    when a score is at least as good as the best so far, optionally plots to
    visdom, and appends one JSON line to ``<save_path>/<experiment_tag>.json``
    (spaces in the tag replaced by underscores). A ``_final.pt`` checkpoint
    is written after the loop ends.
    """

    time_1 = datetime.datetime.now()
    config = generic.load_config()
    agent = TextDQNAgent(config)

    env_type = config["env"]["type"]
    id_eval_env, num_id_eval_game = None, 0
    ood_eval_env, num_ood_eval_game = None, 0
    if env_type == "Hybrid":
        # Hybrid training alternates between the two environments below.
        # NOTE(review): no evaluation environments are created on this
        # branch, so id_eval_env / ood_eval_env stay None in Hybrid mode.
        thor = getattr(importlib.import_module("environment"),
                       "AlfredThorEnv")(config)
        tw = getattr(importlib.import_module("environment"),
                     "AlfredTWEnv")(config)

        thor_env = thor.init_env(batch_size=agent.batch_size)
        tw_env = tw.init_env(batch_size=agent.batch_size)
    else:
        alfred_env = getattr(importlib.import_module("environment"),
                             config["env"]["type"])(config, train_eval="train")
        env = alfred_env.init_env(batch_size=agent.batch_size)

        if agent.run_eval:
            # in distribution
            if config['dataset']['eval_id_data_path'] is not None:
                alfred_env = getattr(
                    importlib.import_module("environment"),
                    config["general"]["evaluate"]["env"]["type"])(
                        config, train_eval="eval_in_distribution")
                id_eval_env = alfred_env.init_env(
                    batch_size=agent.eval_batch_size)
                num_id_eval_game = alfred_env.num_games
            # out of distribution
            if config['dataset']['eval_ood_data_path'] is not None:
                alfred_env = getattr(
                    importlib.import_module("environment"),
                    config["general"]["evaluate"]["env"]["type"])(
                        config, train_eval="eval_out_of_distribution")
                ood_eval_env = alfred_env.init_env(
                    batch_size=agent.eval_batch_size)
                num_ood_eval_game = alfred_env.num_games

    # Checkpoints are written to and loaded from the same save_path.
    output_dir = config["general"]["save_path"]
    data_dir = config["general"]["save_path"]

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # visdom
    if config["general"]["visdom"]:
        import visdom
        viz = visdom.Visdom()
        # Window handles stay None until the first plot creates them.
        reward_win, step_win = None, None
        dqn_loss_win = None
        viz_game_points, viz_step, viz_overall_rewards = [], [], []
        viz_id_eval_game_points, viz_id_eval_step = [], []
        viz_ood_eval_game_points, viz_ood_eval_step = [], []
        viz_dqn_loss = []

    step_in_total = 0
    episode_no = 0
    # presumably running averages over the last 500 pushed values — confirm
    # against HistoryScoreCache
    running_avg_game_points = HistoryScoreCache(capacity=500)
    running_avg_overall_rewards = HistoryScoreCache(capacity=500)
    running_avg_game_steps = HistoryScoreCache(capacity=500)
    running_avg_dqn_loss = HistoryScoreCache(capacity=500)

    json_file_name = agent.experiment_tag.replace(" ", "_")
    best_performance_so_far, best_ood_performance_so_far = 0.0, 0.0
    episodic_counting_memory = EpisodicCountingMemory(
    )  # episodic counting based memory
    obj_centric_episodic_counting_memory = ObjCentricEpisodicMemory()

    # load model from checkpoint
    if agent.load_pretrained:
        # A missing checkpoint file is silently ignored.
        if os.path.exists(data_dir + "/" + agent.load_from_tag + ".pt"):
            agent.load_pretrained_model(data_dir + "/" + agent.load_from_tag +
                                        ".pt")
            agent.update_target_net()

    # Main loop: each iteration plays one batch of games to completion.
    while (True):
        if episode_no > agent.max_episode:
            break

        # hybrid env switching
        if env_type == "Hybrid":
            if random.uniform(0, 1) < config["env"]["hybrid_tw_prob"]:
                env = tw_env
            else:
                env = thor_env

        # Seed numpy and the environment with the episode counter.
        np.random.seed(episode_no)
        env.seed(episode_no)
        obs, infos = env.reset()
        batch_size = len(obs)

        agent.train()
        agent.init(batch_size)
        episodic_counting_memory.reset(
        )  # reset episodic counting based memory
        obj_centric_episodic_counting_memory.reset(
        )  # reset object centric episodic counting based memory
        previous_dynamics = None

        # Per-game placeholders for "the previous step" before any step ran.
        chosen_actions = []
        prev_step_dones, prev_rewards = [], []
        for _ in range(batch_size):
            chosen_actions.append("restart")
            prev_step_dones.append(0.0)
            prev_rewards.append(0.0)

        observation_strings = list(obs)
        task_desc_strings, observation_strings = agent.get_task_and_obs(
            observation_strings)
        task_desc_strings = agent.preprocess_task(task_desc_strings)
        observation_strings = agent.preprocess_observation(observation_strings)
        first_sight_strings = copy.deepcopy(observation_strings)
        agent.observation_pool.push_first_sight(first_sight_strings)
        if agent.action_space == "exhaustive":
            action_candidate_list = [
                extract_admissible_commands(intro, obs)
                for intro, obs in zip(first_sight_strings, observation_strings)
            ]
        else:
            action_candidate_list = list(infos["admissible_commands"])
        action_candidate_list = agent.preprocess_action_candidates(
            action_candidate_list)
        # Keep the observation without the appended action for the memories.
        observation_only = observation_strings
        observation_strings = [
            item + " [SEP] " + a
            for item, a in zip(observation_strings, chosen_actions)
        ]  # appending the chosen action at previous step into the observation
        episodic_counting_memory.push(
            observation_only)  # update init observation into memory
        obj_centric_episodic_counting_memory.push(observation_only)

        # it requires to store sequences of transitions into memory with order,
        # so we use a cache to keep what agents returns, and push them into memory
        # altogether in the end of game.
        transition_cache = []
        still_running_mask = []
        sequence_game_rewards, sequence_count_rewards, sequence_novel_object_rewards, sequence_game_points = [], [], [], []
        print_actions = []

        # Act randomly before learning starts, unless noisy nets are enabled.
        act_randomly = False if agent.noisy_net else episode_no < agent.learn_start_from_this_episode
        for step_no in range(agent.max_nb_steps_per_episode):
            # push obs into observation pool
            agent.observation_pool.push_batch(observation_strings)
            # get most recent k observations
            most_recent_observation_strings = agent.observation_pool.get()
            if agent.noisy_net:
                agent.reset_noise()  # Draw a new set of noisy weights

            # predict actions
            if agent.action_space == "generation":
                chosen_actions, chosen_indices, current_dynamics = agent.command_generation_act(
                    most_recent_observation_strings,
                    task_desc_strings,
                    previous_dynamics,
                    random=act_randomly)
            elif agent.action_space == "beam_search_choice":
                chosen_actions, chosen_indices, current_dynamics, action_candidate_list = agent.beam_search_choice_act(
                    most_recent_observation_strings,
                    task_desc_strings,
                    previous_dynamics,
                    random=act_randomly)
            elif agent.action_space in ["admissible", "exhaustive"]:
                chosen_actions, chosen_indices, current_dynamics = agent.admissible_commands_act(
                    most_recent_observation_strings,
                    task_desc_strings,
                    action_candidate_list,
                    previous_dynamics,
                    random=act_randomly)
            else:
                raise NotImplementedError()

            # Cache what the agent saw and chose at this step; it is pushed
            # into the replay memory after the episode.
            replay_info = [
                most_recent_observation_strings, task_desc_strings,
                action_candidate_list, chosen_indices
            ]
            transition_cache.append(replay_info)
            obs, _, dones, infos = env.step(chosen_actions)
            # The environment's own reward is ignored; infos["won"] is used.
            scores = [float(item) for item in infos["won"]]
            dones = [float(item) for item in dones]

            observation_strings = list(obs)
            observation_strings = agent.preprocess_observation(
                observation_strings)
            if agent.action_space == "exhaustive":
                action_candidate_list = [
                    extract_admissible_commands(intro, obs) for intro, obs in
                    zip(first_sight_strings, observation_strings)
                ]
            else:
                action_candidate_list = list(infos["admissible_commands"])
            action_candidate_list = agent.preprocess_action_candidates(
                action_candidate_list)
            observation_only = observation_strings
            observation_strings = [
                item + " [SEP] " + a
                for item, a in zip(observation_strings, chosen_actions)
            ]  # appending the chosen action at previous step into the observation
            # Novelty is queried before the new observation is pushed.
            seeing_new_states = episodic_counting_memory.is_a_new_state(
                observation_only)
            seeing_new_objects = obj_centric_episodic_counting_memory.get_object_novelty_reward(
                observation_only)
            episodic_counting_memory.push(
                observation_only)  # update new observation into memory
            obj_centric_episodic_counting_memory.push(
                observation_only)  # update new observation into memory
            previous_dynamics = current_dynamics

            if agent.noisy_net and step_in_total % agent.update_per_k_game_steps == 0:
                agent.reset_noise()  # Draw a new set of noisy weights

            if episode_no >= agent.learn_start_from_this_episode and step_in_total % agent.update_per_k_game_steps == 0:
                dqn_loss, _ = agent.update_dqn()
                if dqn_loss is not None:
                    running_avg_dqn_loss.push(dqn_loss)

            if step_no == agent.max_nb_steps_per_episode - 1:
                # terminate the game because DQN requires one extra step
                dones = [1.0 for _ in dones]

            step_in_total += 1
            # A game counts as running at this step if it was not done after
            # the previous step.
            still_running = [1.0 - float(item)
                             for item in prev_step_dones]  # list of float
            prev_step_dones = dones
            # Step reward is the change in the "won" score since the last step.
            step_rewards = [
                float(curr) - float(prev)
                for curr, prev in zip(scores, prev_rewards)
            ]  # list of float
            count_rewards = [
                r * agent.count_reward_lambda for r in seeing_new_states
            ]  # list of float
            novel_object_rewards = [
                r * agent.novel_object_reward_lambda
                for r in seeing_new_objects
            ]  # list of novel object rewards
            sequence_game_points.append(copy.copy(step_rewards))
            prev_rewards = scores
            still_running_mask.append(still_running)
            sequence_game_rewards.append(step_rewards)
            sequence_count_rewards.append(count_rewards)
            sequence_novel_object_rewards.append(novel_object_rewards)
            # Only the first game of the batch is logged for printing.
            print_actions.append(
                chosen_actions[0] if still_running[0] else "--")

            # if all ended, break
            if np.sum(still_running) == 0:
                break

        # Multiply by the mask to zero out anything recorded after a game
        # had already ended.
        still_running_mask_np = np.array(still_running_mask)
        game_rewards_np = np.array(
            sequence_game_rewards) * still_running_mask_np  # step x batch
        count_rewards_np = np.array(
            sequence_count_rewards) * still_running_mask_np  # step x batch
        novel_object_rewards_np = np.array(
            sequence_novel_object_rewards) * still_running_mask_np
        game_points_np = np.array(
            sequence_game_points) * still_running_mask_np  # step x batch
        game_rewards_pt = generic.to_pt(game_rewards_np,
                                        enable_cuda=False,
                                        type='float')  # step x batch
        count_rewards_pt = generic.to_pt(count_rewards_np,
                                         enable_cuda=False,
                                         type='float')  # step x batch
        novel_object_rewards_pt = generic.to_pt(novel_object_rewards_np,
                                                enable_cuda=False,
                                                type='float')

        # push experience into replay buffer (dqn)
        # NOTE: the loop below rebinds observation_strings,
        # action_candidate_list and chosen_indices; all three are
        # re-initialized at the start of the next episode.
        avg_reward_in_replay_buffer = agent.dqn_memory.get_avg_rewards()
        for b in range(game_rewards_np.shape[1]):
            if still_running_mask_np.shape[
                    0] == agent.max_nb_steps_per_episode and still_running_mask_np[
                        -1][b] != 0:
                # need to pad one transition
                avg_reward = game_rewards_np[:, b].tolist() + [0.0]
                _need_pad = True
            else:
                avg_reward = game_rewards_np[:, b]
                _need_pad = False
            avg_reward = np.mean(avg_reward)
            # Passed to dqn_memory.push as its first argument.
            is_prior = avg_reward >= avg_reward_in_replay_buffer

            mem = []
            for i in range(game_rewards_np.shape[0]):
                observation_strings, task_strings, action_candidate_list, chosen_indices = transition_cache[
                    i]
                mem.append([
                    observation_strings[b], task_strings[b],
                    action_candidate_list[b], chosen_indices[b],
                    game_rewards_pt[i][b], count_rewards_pt[i][b],
                    novel_object_rewards_pt[i][b]
                ])
                # The first masked-out step is still stored, then we stop.
                if still_running_mask_np[i][b] == 0.0:
                    break
            if _need_pad:
                # Repeat the last transition with all rewards zeroed.
                observation_strings, task_strings, action_candidate_list, chosen_indices = transition_cache[
                    -1]
                mem.append([
                    observation_strings[b], task_strings[b],
                    action_candidate_list[b], chosen_indices[b],
                    game_rewards_pt[-1][b] * 0.0,
                    count_rewards_pt[-1][b] * 0.0,
                    novel_object_rewards_pt[-1][b] * 0.0
                ])
            agent.dqn_memory.push(is_prior, avg_reward, mem)

        for b in range(batch_size):
            running_avg_game_points.push(np.sum(game_points_np, 0)[b])
            running_avg_overall_rewards.push(
                np.sum(game_rewards_np, 0)[b] +
                np.sum(count_rewards_np, 0)[b] +
                np.sum(novel_object_rewards_np, 0)[b])
            running_avg_game_steps.push(np.sum(still_running_mask_np, 0)[b])

        # finish game
        agent.finish_of_episode(episode_no, batch_size)
        # episode_no advances by the number of games just played.
        episode_no += batch_size

        if episode_no < agent.learn_start_from_this_episode:
            continue
        # Skip reporting when disabled (frequency 0) or when the remainder
        # modulo report_frequency grew during this batch; a report happens
        # only when the remainder did not grow.
        if agent.report_frequency == 0 or (
                episode_no % agent.report_frequency >
            (episode_no - batch_size) % agent.report_frequency):
            continue
        time_2 = datetime.datetime.now()
        # Each "a/b" pair prints the current batch value and the running average.
        print(
            "Episode: {:3d} | time spent: {:s} | dqn loss: {:2.3f} | overall rewards: {:2.3f}/{:2.3f} | game points: {:2.3f}/{:2.3f} | used steps: {:2.3f}/{:2.3f}"
            .format(
                episode_no,
                str(time_2 - time_1).rsplit(".")[0],
                running_avg_dqn_loss.get_avg(),
                np.mean(
                    np.sum(game_rewards_np, 0) + np.sum(count_rewards_np, 0) +
                    np.sum(novel_object_rewards_np, 0)),
                running_avg_overall_rewards.get_avg(),
                np.mean(np.sum(game_points_np, 0)),
                running_avg_game_points.get_avg(),
                np.mean(np.sum(still_running_mask_np, 0)),
                running_avg_game_steps.get_avg()))
        # print(game_id + ":    " + " | ".join(print_actions))
        print(" | ".join(print_actions))

        # evaluate
        # These 0.0 defaults are what gets compared, plotted and logged when
        # the corresponding evaluation environment is None.
        id_eval_game_points, id_eval_game_step = 0.0, 0.0
        ood_eval_game_points, ood_eval_game_step = 0.0, 0.0
        if agent.run_eval:
            if id_eval_env is not None:
                id_eval_res = evaluate_dqn(id_eval_env, agent,
                                           num_id_eval_game)
                id_eval_game_points, id_eval_game_step = id_eval_res[
                    'average_points'], id_eval_res['average_steps']
            if ood_eval_env is not None:
                ood_eval_res = evaluate_dqn(ood_eval_env, agent,
                                            num_ood_eval_game)
                ood_eval_game_points, ood_eval_game_step = ood_eval_res[
                    'average_points'], ood_eval_res['average_steps']
            # ">=" means ties with the best score also overwrite the checkpoint.
            if id_eval_game_points >= best_performance_so_far:
                best_performance_so_far = id_eval_game_points
                agent.save_model_to_path(output_dir + "/" +
                                         agent.experiment_tag + "_id.pt")
            if ood_eval_game_points >= best_ood_performance_so_far:
                best_ood_performance_so_far = ood_eval_game_points
                agent.save_model_to_path(output_dir + "/" +
                                         agent.experiment_tag + "_ood.pt")
        else:
            # Without evaluation, model selection uses the running average of
            # training game points.
            if running_avg_game_points.get_avg() >= best_performance_so_far:
                best_performance_so_far = running_avg_game_points.get_avg()
                agent.save_model_to_path(output_dir + "/" +
                                         agent.experiment_tag + ".pt")

        # plot using visdom
        if config["general"]["visdom"]:
            viz_game_points.append(running_avg_game_points.get_avg())
            viz_overall_rewards.append(running_avg_overall_rewards.get_avg())
            viz_step.append(running_avg_game_steps.get_avg())
            viz_dqn_loss.append(running_avg_dqn_loss.get_avg())
            viz_id_eval_game_points.append(id_eval_game_points)
            viz_id_eval_step.append(id_eval_game_step)
            viz_ood_eval_game_points.append(ood_eval_game_points)
            viz_ood_eval_step.append(ood_eval_game_step)
            viz_x = np.arange(len(viz_game_points)).tolist()

            # First report creates the window with the full history; later
            # reports append only the newest point to each named trace.
            if reward_win is None:
                reward_win = viz.line(X=viz_x,
                                      Y=viz_game_points,
                                      opts=dict(title=agent.experiment_tag +
                                                "_game_points"),
                                      name="game points")
                viz.line(X=viz_x,
                         Y=viz_overall_rewards,
                         opts=dict(title=agent.experiment_tag +
                                   "_overall_rewards"),
                         win=reward_win,
                         update='append',
                         name="overall rewards")
                viz.line(X=viz_x,
                         Y=viz_id_eval_game_points,
                         opts=dict(title=agent.experiment_tag +
                                   "_id_eval_game_points"),
                         win=reward_win,
                         update='append',
                         name="id eval game points")
                viz.line(X=viz_x,
                         Y=viz_ood_eval_game_points,
                         opts=dict(title=agent.experiment_tag +
                                   "_ood_eval_game_points"),
                         win=reward_win,
                         update='append',
                         name="ood eval game points")
            else:
                viz.line(X=[len(viz_game_points) - 1],
                         Y=[viz_game_points[-1]],
                         opts=dict(title=agent.experiment_tag +
                                   "_game_points"),
                         win=reward_win,
                         update='append',
                         name="game points")
                viz.line(X=[len(viz_overall_rewards) - 1],
                         Y=[viz_overall_rewards[-1]],
                         opts=dict(title=agent.experiment_tag +
                                   "_overall_rewards"),
                         win=reward_win,
                         update='append',
                         name="overall rewards")
                viz.line(X=[len(viz_id_eval_game_points) - 1],
                         Y=[viz_id_eval_game_points[-1]],
                         opts=dict(title=agent.experiment_tag +
                                   "_id_eval_game_points"),
                         win=reward_win,
                         update='append',
                         name="id eval game points")
                viz.line(X=[len(viz_ood_eval_game_points) - 1],
                         Y=[viz_ood_eval_game_points[-1]],
                         opts=dict(title=agent.experiment_tag +
                                   "_ood_eval_game_points"),
                         win=reward_win,
                         update='append',
                         name="ood eval game points")

            if step_win is None:
                step_win = viz.line(X=viz_x,
                                    Y=viz_step,
                                    opts=dict(title=agent.experiment_tag +
                                              "_step"),
                                    name="step")
                viz.line(X=viz_x,
                         Y=viz_id_eval_step,
                         opts=dict(title=agent.experiment_tag +
                                   "_id_eval_step"),
                         win=step_win,
                         update='append',
                         name="id eval step")
                viz.line(X=viz_x,
                         Y=viz_ood_eval_step,
                         opts=dict(title=agent.experiment_tag +
                                   "_ood_eval_step"),
                         win=step_win,
                         update='append',
                         name="ood eval step")
            else:
                viz.line(X=[len(viz_step) - 1],
                         Y=[viz_step[-1]],
                         opts=dict(title=agent.experiment_tag + "_step"),
                         win=step_win,
                         update='append',
                         name="step")
                viz.line(X=[len(viz_id_eval_step) - 1],
                         Y=[viz_id_eval_step[-1]],
                         opts=dict(title=agent.experiment_tag +
                                   "_id_eval_step"),
                         win=step_win,
                         update='append',
                         name="id eval step")
                viz.line(X=[len(viz_ood_eval_step) - 1],
                         Y=[viz_ood_eval_step[-1]],
                         opts=dict(title=agent.experiment_tag +
                                   "_ood_eval_step"),
                         win=step_win,
                         update='append',
                         name="ood eval step")

            if dqn_loss_win is None:
                dqn_loss_win = viz.line(X=viz_x,
                                        Y=viz_dqn_loss,
                                        opts=dict(title=agent.experiment_tag +
                                                  "_dqn_loss"),
                                        name="dqn loss")
            else:
                viz.line(X=[len(viz_dqn_loss) - 1],
                         Y=[viz_dqn_loss[-1]],
                         opts=dict(title=agent.experiment_tag + "_dqn_loss"),
                         win=dqn_loss_win,
                         update='append',
                         name="dqn loss")

        # write accuracies down into file
        # All values are stringified; one JSON object is appended per report.
        _s = json.dumps({
            "time spent":
            str(time_2 - time_1).rsplit(".")[0],
            "dqn loss":
            str(running_avg_dqn_loss.get_avg()),
            "overall rewards":
            str(running_avg_overall_rewards.get_avg()),
            "train game points":
            str(running_avg_game_points.get_avg()),
            "train steps":
            str(running_avg_game_steps.get_avg()),
            "id eval game points":
            str(id_eval_game_points),
            "id eval steps":
            str(id_eval_game_step),
            "ood eval game points":
            str(ood_eval_game_points),
            "ood eval steps":
            str(ood_eval_game_step)
        })
        with open(output_dir + "/" + json_file_name + '.json',
                  'a+') as outfile:
            outfile.write(_s + '\n')
            outfile.flush()
    # Always save the last weights once the episode budget is exhausted.
    agent.save_model_to_path(output_dir + "/" + agent.experiment_tag +
                             "_final.pt")
Example #6
0
    def command_generation_beam_search_generation(self, observation_strings, task_desc_strings, previous_dynamics):
        """Generate candidate commands for every batch item using beam search.

        Observations and task descriptions are encoded once for the whole
        batch; decoding then runs independently per batch item as a best-first
        search over partial token sequences held in a priority queue.

        Args:
            observation_strings: one observation per batch item (passed to
                ``self.encode`` and ``self.get_word_input``).
            task_desc_strings: one task description per batch item (passed to
                ``self.encode``).
            previous_dynamics: recurrent state from the previous step, or
                ``None``; only used when ``self.recurrent`` is truthy.

        Returns:
            A ``(res, current_dynamics)`` tuple. ``res`` holds, for each batch
            item, a list of decoded command strings ordered best-first (at most
            ``self.generate_top_k`` on the beam-search path, exactly one when
            ``self.beam_width == 1``). ``current_dynamics`` is the updated
            recurrent state, or ``None`` when the agent is not recurrent.
        """
        with torch.no_grad():

            batch_size = len(observation_strings)
            beam_width = self.beam_width
            if beam_width == 1:
                # A beam of one is plain greedy decoding; wrap each result in a
                # list so the return shape matches the beam-search path.
                res, current_dynamics = self.command_generation_greedy_generation(observation_strings, task_desc_strings, previous_dynamics)
                res = [[item] for item in res]
                return res, current_dynamics
            generate_top_k = self.generate_top_k
            cls_id = self.word2id["[CLS]"]
            sep_id = self.word2id["[SEP]"]
            res = []

            input_obs = self.get_word_input(observation_strings)
            h_obs, obs_mask = self.encode(observation_strings, use_model="online")
            h_td, td_mask = self.encode(task_desc_strings, use_model="online")
            aggregated_obs_representation = self.online_net.aggretate_information(h_obs, obs_mask, h_td, td_mask)  # batch x obs_length x hid

            if self.recurrent:
                averaged_representation = self.online_net.masked_mean(aggregated_obs_representation, obs_mask)  # batch x hid
                if previous_dynamics is not None:
                    current_dynamics = self.online_net.rnncell(averaged_representation, previous_dynamics)
                else:
                    current_dynamics = self.online_net.rnncell(averaged_representation)
            else:
                current_dynamics = None

            for b in range(batch_size):
                # Slice with b:b+1 (not b) so every tensor keeps its batch axis.
                item_input_obs = input_obs[b: b + 1]  # 1 x obs_len
                item_obs_mask = obs_mask[b: b + 1]  # 1 x obs_len
                item_obs_representation = aggregated_obs_representation[b: b + 1]  # 1 x obs_len x hid
                item_dynamics = current_dynamics[b: b + 1] if current_dynamics is not None else None  # 1 x hid

                # Queue entries are (score, insertion_order, node). The
                # insertion counter breaks score ties so the queue never has to
                # compare two BeamSearchNode objects with each other.
                insertion_order = 0
                # starting node - previous node, input target, logp, length
                root = BeamSearchNode(None, [cls_id], 0, 1)
                nodes_queue = PriorityQueue()
                nodes_queue.put((root.val, insertion_order, root))
                queue_size = 1
                ended_nodes = []

                # Give up when decoding takes too long (queue budget of 2000),
                # and stop when no hypotheses are left: PriorityQueue.get()
                # would otherwise block forever on an empty queue.
                while queue_size <= 2000 and not nodes_queue.empty():
                    # fetch the best node
                    score, _, n = nodes_queue.get()

                    reached_end = n.input_target[-1] == sep_id or n.length >= self.max_target_length
                    if reached_end and n.previous_node is not None:
                        ended_nodes.append((score, n))
                        # stop once we have the required number of sentences
                        if len(ended_nodes) >= generate_top_k:
                            break
                        continue

                    input_target = pad_sequences([n.input_target], dtype='int32')
                    input_target = to_pt(input_target, self.use_cuda)
                    target_mask = compute_mask(input_target)
                    # decode for one step using decoder
                    pred = self.online_net.decode(input_target, target_mask, item_obs_representation,
                                                  item_obs_mask, item_dynamics, item_input_obs)  # 1 x target_length x vocab
                    pred = pred[0][-1].cpu()  # vocab

                    # Log-probabilities of the next token. Entries that are not
                    # strictly positive get -inf so they rank last; multiplying
                    # their log by zero would score them 0.0, i.e. above every
                    # token that has a real (negative) log-probability.
                    positive = torch.gt(pred, 0.0)
                    safe_pred = torch.where(positive, pred, torch.ones_like(pred))
                    log_pred = torch.where(positive, torch.log(safe_pred),
                                           torch.full_like(pred, float("-inf")))  # vocab

                    top_log_probs, top_indices = torch.topk(log_pred, beam_width)
                    # .tolist() yields plain Python numbers, so the token list
                    # of every hypothesis stays a list of ints, not 0-dim tensors.
                    for log_p, token_id in zip(top_log_probs.tolist(), top_indices.tolist()):
                        child = BeamSearchNode(n, n.input_target + [token_id], n.log_prob + log_p, n.length + 1)
                        insertion_order += 1
                        nodes_queue.put((child.val, insertion_order, child))
                    # one node was popped, beam_width nodes were pushed
                    queue_size += beam_width - 1

                # No hypothesis finished: fall back to the best unfinished
                # ones, never asking for more entries than the queue holds.
                if not ended_nodes:
                    for _ in range(min(generate_top_k, nodes_queue.qsize())):
                        score, _, n = nodes_queue.get()
                        ended_nodes.append((score, n))

                utterances = []
                for score, n in sorted(ended_nodes, key=operator.itemgetter(0)):
                    utte_string = self.tokenizer.decode(n.input_target)
                    utte_string = utte_string.replace("[CLS]", "").replace("[SEP]", "").strip()
                    utterances.append(utte_string.replace(" in / on ", " in/on "))
                res.append(utterances)
            return res, current_dynamics