def __init__(self, shared_model, local_model, action_space, meta_data_util, config, constants, tensorboard):
        """Store the learner's collaborators and build its optimizer.

        Auxiliary-loss calculators are created only for the objectives whose
        ``do_*`` flag in ``config`` is truthy; each of them wraps
        ``local_model``.  The Adam optimizer is built over the parameters
        returned by ``shared_model.get_parameters()``.

        Reads ``constants["max_epochs"]``, ``constants["entropy_coefficient"]``
        and ``constants["learning_rate"]``.
        """
        self.max_epoch = constants["max_epochs"]

        # Collaborators handed in by the caller.
        self.shared_model, self.local_model = shared_model, local_model
        self.action_space, self.meta_data_util = action_space, meta_data_util
        self.config, self.constants = config, constants
        self.tensorboard = tensorboard

        # Statistics that are filled in later.
        self.entropy = self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]
        self.logger = None

        # Auxiliary objectives: one calculator per enabled flag, loss starts as None.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(local_model)
            self.action_prediction_loss = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(local_model)
            self.temporal_autoencoder_loss = None
        if config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(local_model, num_objects=67)
            self.object_detection_loss = None
        if config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(local_model)
            self.symbolic_language_prediction_loss = None
        if config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(local_model)
            self.goal_prediction_loss = None

        # The optimizer works on the shared model's parameters.
        trainable = shared_model.get_parameters()
        self.optimizer = optim.Adam(trainable, lr=constants["learning_rate"])

        # Base-class setup comes last because it receives the optimizer and calc_loss.
        AbstractLearning.__init__(
            self, shared_model, local_model, self.calc_loss,
            self.optimizer, config, constants, tensorboard)
    def __init__(self, shared_navigator_model, local_navigator_model,
                 shared_predictor_model, local_predictor_model, action_space,
                 meta_data_util, config, constants, tensorboard):
        """Set up a learner that optimizes a navigator and a predictor model jointly.

        Auxiliary-loss calculators are created only for the objectives whose
        ``do_*`` flag in ``config`` is truthy; each wraps
        ``local_navigator_model``.  A single Adam optimizer is built over the
        parameters of both shared models.

        Reads ``constants["max_epochs"]``, ``constants["entropy_coefficient"]``
        and ``constants["learning_rate"]``.
        """
        self.max_epoch = constants["max_epochs"]
        self.shared_navigator_model = shared_navigator_model
        self.local_navigator_model = local_navigator_model
        self.shared_predictor_model = shared_predictor_model
        self.local_predictor_model = local_predictor_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.ratio = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        # Dimensions reported by the navigator's image module; reused by the
        # pixel-level auxiliary objectives below.
        self.image_channels, self.image_height, self.image_width = \
            shared_navigator_model.image_module.get_final_dimension()

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_navigator_model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_navigator_model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.local_navigator_model,
                num_objects=67,
                camera_angle=60,
                image_height=self.image_height,
                image_width=self.image_width,
                object_height=0)  # -2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_navigator_model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.local_navigator_model, self.image_height,
                self.image_width)
            self.goal_prediction_loss = None

        # Copy into a fresh list before extending: extending the object
        # returned by get_parameters() in place would silently mutate a list
        # owned by the model (and would fail outright if it were a generator).
        parameters = list(self.shared_navigator_model.get_parameters())
        parameters.extend(self.shared_predictor_model.get_parameters())

        self.optimizer = optim.Adam(parameters, lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_navigator_model,
                                  self.local_navigator_model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Store collaborators, build enabled auxiliary calculators and the optimizer.

        All calculators wrap ``model``; the Adam optimizer is built over
        ``model.get_parameters()``.  Object detection is currently stubbed
        out: when its flag is set the calculator is left as ``None``.

        Reads ``constants["max_epochs"]``, ``constants["entropy_coefficient"]``
        and ``constants["learning_rate"]``.
        """
        self.max_epoch = constants["max_epochs"]

        # Collaborators handed in by the caller.
        self.model = model
        self.action_space, self.meta_data_util = action_space, meta_data_util
        self.config, self.constants = config, constants
        self.tensorboard = tensorboard

        # Running statistics and counters.
        self.entropy = self.cross_entropy = None
        self.epoch = self.global_id = 0
        self.entropy_coef = constants["entropy_coefficient"]

        # Final feature-map dimensions as reported by the model's image module.
        final_dimension = model.image_module.get_final_dimension()
        self.final_num_channels, self.final_height, self.final_width = final_dimension

        self.ignore_none = False
        self.inference_procedure = GoalPredictionHouseSingle360ImageSupervisedLearningFromDisk.MODE

        # Auxiliary objectives: one calculator per enabled flag, loss starts as None.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(model)
            self.action_prediction_loss = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(model)
            self.temporal_autoencoder_loss = None
        if config["do_object_detection"]:
            # No calculator is built here; an ObjectPixelIdentification-based
            # one was previously used and is currently disabled.
            self.object_detection_loss_calculator = None
            self.object_detection_loss = None
        if config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(model)
            self.symbolic_language_prediction_loss = None
        if config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                model, self.final_height, self.final_width)
            self.goal_prediction_loss = None

        # Loss components that always exist, regardless of the flags above.
        self.cross_entropy_loss = None
        self.dist_loss = None
        self.object_detection_loss = None

        trainable = model.get_parameters()
        self.optimizer = optim.Adam(trainable, lr=constants["learning_rate"])

        # Base-class setup comes last because it receives the optimizer and calc_loss.
        AbstractLearning.__init__(
            self, model, self.calc_loss, self.optimizer,
            config, constants, tensorboard)

        logging.info("Created Single Image goal predictor with ignore_none %r",
                     self.ignore_none)
예제 #4
0
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Store collaborators, set the DAGGER beta schedule and build the optimizer.

        Auxiliary-loss calculators are created only for the objectives whose
        ``do_*`` flag in ``config`` is truthy; each wraps ``model``.  The Adam
        optimizer is built over ``model.get_parameters()``.

        Reads ``constants["max_epochs"]``, ``config["num_client"]``,
        ``constants["entropy_coefficient"]`` and ``constants["learning_rate"]``.
        """
        self.max_epoch = constants["max_epochs"]

        # Collaborators handed in by the caller.
        self.model = model
        self.action_space, self.meta_data_util = action_space, meta_data_util
        self.config, self.constants = config, constants
        self.num_client = config["num_client"]
        self.tensorboard = tensorboard

        # Statistics that are filled in later.
        self.entropy = self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]

        # Starting beta and its exponential decay factor (both fixed at 0.9).
        self.beta = self.beta_exp_decay = 0.9
        logging.info(
            "DAGGER: using starting beta of %r and beta exp decay of %r",
            self.beta, self.beta_exp_decay)

        # Auxiliary objectives: one calculator per enabled flag, loss starts as None.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(model)
            self.action_prediction_loss = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(model)
            self.temporal_autoencoder_loss = None
        if config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(model, num_objects=67)
            self.object_detection_loss = None
        if config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(model)
            self.symbolic_language_prediction_loss = None

        trainable = model.get_parameters()
        self.optimizer = optim.Adam(trainable, lr=constants["learning_rate"])

        # Base-class setup comes last because it receives the optimizer and calc_loss.
        AbstractLearning.__init__(
            self, model, self.calc_loss, self.optimizer,
            config, constants, tensorboard)
class TmpAsynchronousContextualBandit(AbstractLearning):
    """ Perform Contextual Bandit learning (Kakade and Langford (circa 2006) & Misra, Langford and Artzi EMNLP 2017) """

    def __init__(self, shared_model, local_model, action_space, meta_data_util, config, constants, tensorboard):
        """Store the learner's collaborators and build its optimizer.

        Auxiliary-loss calculators are created only for the objectives whose
        ``do_*`` flag in ``config`` is truthy; each wraps ``local_model``.
        The Adam optimizer is built over ``shared_model.get_parameters()``.

        Every statistic that ``calc_loss`` later fills in (``entropy``,
        ``ratio`` and the five auxiliary losses) is initialised to ``None``
        here, so reading one before the first update yields ``None`` instead
        of raising ``AttributeError``.

        Reads ``constants["max_epochs"]``, ``constants["entropy_coefficient"]``
        and ``constants["learning_rate"]``.
        """
        self.max_epoch = constants["max_epochs"]
        self.shared_model = shared_model
        self.local_model = local_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.ratio = None  # set by calc_loss: |objective| / (entropy_coef * entropy)
        self.entropy_coef = constants["entropy_coefficient"]
        self.logger = None

        # Auxiliary losses always exist as attributes, whether or not the
        # corresponding objective is enabled.
        self.action_prediction_loss = None
        self.temporal_autoencoder_loss = None
        self.object_detection_loss = None
        self.symbolic_language_prediction_loss = None
        self.goal_prediction_loss = None

        # Auxiliary Objectives: calculators are only built when enabled.
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(self.local_model)
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(self.local_model)
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(self.local_model, num_objects=67)
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(self.local_model)
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(self.local_model)

        self.optimizer = optim.Adam(shared_model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_model, self.local_model, self.calc_loss,
                                  self.optimizer, self.config, self.constants, self.tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the contextual-bandit loss, plus enabled auxiliary losses, for one batch.

        For every replay item the log-probability of the action that was
        taken is weighted by the item's reward; the sum over the batch is the
        objective.  The returned loss is
        ``-objective - entropy_coef * entropy`` plus every auxiliary loss that
        is enabled in ``self.config`` and whose calculator returned a value.

        Side effects: updates ``self.entropy``, ``self.ratio`` and the five
        ``self.*_loss`` auxiliary attributes.  An auxiliary attribute is
        ``None`` when its objective is disabled or its calculator returned
        ``None``; otherwise it holds the coefficient-weighted loss.

        :param batch_replay_items: iterable of replay items exposing
            ``get_action()``, ``get_reward()`` and ``get_log_prob()``.
        :return: the total loss tensor.
        """
        action_batch = []
        immediate_rewards = []
        log_probabilities = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())
            immediate_rewards.append(replay_item.get_reward())
            log_probabilities.append(replay_item.get_log_prob())

        log_probabilities = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))
        immediate_rewards = cuda_var(torch.from_numpy(np.array(immediate_rewards)).float())

        # Pick, per row, the log-probability of the action actually taken
        # (dimension 1 indexes actions) and weight it by the observed reward.
        chosen_log_probs = log_probabilities.gather(1, action_batch.view(-1, 1))
        reward_log_probs = immediate_rewards * chosen_log_probs.view(-1)

        # Entropy of the action distributions, summed over the whole batch.
        probabilities = torch.exp(log_probabilities)
        self.entropy = -torch.sum(torch.sum(log_probabilities * probabilities, 1))
        objective = torch.sum(reward_log_probs)

        # Minimising the loss raises the objective; the entropy term acts as
        # a bonus scaled by entropy_coef.
        loss = -objective - self.entropy_coef * self.entropy
        self.ratio = torch.abs(objective) / (self.entropy_coef * self.entropy)  # we want the ratio to be high

        # (config flag, loss attribute, calculator attribute, coefficient key)
        auxiliary_objectives = (
            ("do_action_prediction", "action_prediction_loss",
             "action_prediction_loss_calculator", "action_prediction_coeff"),
            ("do_temporal_autoencoding", "temporal_autoencoder_loss",
             "temporal_autoencoder_loss_calculator", "temporal_autoencoder_coeff"),
            ("do_object_detection", "object_detection_loss",
             "object_detection_loss_calculator", "object_detection_coeff"),
            ("do_symbolic_language_prediction", "symbolic_language_prediction_loss",
             "symbolic_language_prediction_loss_calculator", "symbolic_language_prediction_coeff"),
            ("do_goal_prediction", "goal_prediction_loss",
             "goal_prediction_calculator", "goal_prediction_coeff"),
        )
        for flag, loss_attr, calculator_attr, coeff_key in auxiliary_objectives:
            weighted_loss = None
            if self.config[flag]:
                weighted_loss = getattr(self, calculator_attr).calc_loss(batch_replay_items)
                # A calculator may have nothing to report for this batch; all
                # five objectives are guarded alike so that a None result is
                # skipped rather than multiplied.
                if weighted_loss is not None:
                    weighted_loss = self.constants[coeff_key] * weighted_loss
                    loss = loss + weighted_loss
            setattr(self, loss_attr, weighted_loss)

        return loss

    @staticmethod
    def convert_text_to_indices(text, vocab, ignore_case=True):
        """Tokenize ``text`` and map every token to its index in ``vocab``.

        The text is echoed to stdout, split with ``nltk.word_tokenize`` and,
        when ``ignore_case`` is truthy, lower-cased token by token.  Tokens
        missing from ``vocab`` are mapped to ``vocab["$UNK$"]``.

        :return: list with one vocabulary entry per token, in order.
        """
        print("instruction ", text)

        tokens = nltk.word_tokenize(text)
        if ignore_case:
            tokens = [tok.lower() for tok in tokens]

        # The unknown-token entry is only looked up when it is needed.
        return [vocab[tok] if tok in vocab else vocab["$UNK$"] for tok in tokens]

    @staticmethod
    def convert_indices_to_text(indices, vocab):
        return " ".join([vocab[index] for index in indices])

    def get_goal(self, metadata):

        if metadata["goal-screen"] is None:
            return None, None, None, None

        left, bottom, depth = metadata["goal-screen"]

        if 0.01 < left < self.config["image_width"] and 0.01 < bottom < self.config["image_height"] and depth > 0.01:

            scaled_left = left / float(self.config["image_width"])
            scaled_top = 1.0 - bottom / float(self.config["image_height"])

            row_real = self.config["num_manipulation_row"] * scaled_top
            col_real = self.config["num_manipulation_col"] * scaled_left
            row, col = round(row_real), round(col_real)

            if row < 0:
                row = 0
            elif row >= self.config["num_manipulation_row"]:
                row = self.config["num_manipulation_row"] - 1
            if col < 0:
                col = 0
            elif col >= self.config["num_manipulation_col"]:
                col = self.config["num_manipulation_col"] - 1

            return row, col, row_real, col_real
        else:
            return None, None, None, None

    @staticmethod
    def do_train(house_id, shared_model, config, action_space, meta_data_util,
                 constants, train_dataset, tune_dataset, experiment,
                 experiment_name, rank, server, logger, model_type, vocab, use_pushover=False):
        try:
            TmpAsynchronousContextualBandit.do_train_(house_id, shared_model, config, action_space, meta_data_util,
                                                      constants, train_dataset, tune_dataset, experiment,
                                                      experiment_name, rank, server, logger, model_type,
                                                      vocab, use_pushover)
        except Exception:
            exc_info = sys.exc_info()
            traceback.print_exception(*exc_info)

    @staticmethod
    def do_train_(house_id, shared_model, config, action_space, meta_data_util, constants,
                  train_dataset, tune_dataset, experiment, experiment_name, rank, server,
                  logger, model_type, vocab, use_pushover=False):
        """Run the contextual-bandit training loop for one client.

        Launches the simulator build for ``house_id``, initialises ``server``,
        creates a local model via ``model_type(config, constants)`` and then,
        for every epoch and every training data point:

        1. copies the shared model's state into the local model,
        2. rolls out the local policy by sampling actions until the stop
           action is sampled or ``horizon + max_extra_horizon`` actions were
           taken,
        3. halts the episode and performs one ``learner.do_update`` on the
           collected replay items,
        4. logs statistics to tensorboard (only the client with ``rank == 0``
           owns a tensorboard).

        After each epoch the local model is saved under ``experiment`` and,
        when ``tune_dataset`` is non-empty, the agent is tested on it (this is
        also done once before training starts).

        NOTE(review): the number of epochs is hard-coded to 100000;
        ``constants["max_epochs"]`` is not consulted here.
        NOTE(review): ``use_pushover`` has no effect - both branches set
        ``pushover_logger`` to ``None``.
        """

        logger.log("In Training...")
        # Start the simulator executable for this house on the configured port.
        launch_k_unity_builds([config["port"]], "./house_" + str(house_id) + "_elmer.x86_64",
                              arg_str="--config ./AssetsHouse/config" + str(house_id) + ".json",
                              cwd="./simulators/house/")
        logger.log("Launched Builds.")
        server.initialize_server()
        logger.log("Server Initialized.")

        # Test policy
        test_policy = gp.get_argmax_action

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
            logger.log('Created Tensorboard Server.')
        else:
            tensorboard = None

        # Both branches currently yield None, so pushover logging is disabled.
        if use_pushover:
            pushover_logger = None
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)
        # local_model.train()

        # Create the Agent
        tmp_agent = TmpHouseAgent(server=server,
                                  model=local_model,
                                  test_policy=test_policy,
                                  action_space=action_space,
                                  meta_data_util=meta_data_util,
                                  config=config,
                                  constants=constants)
        logger.log("Created Agent.")

        # Running count of how often each action was sampled during training.
        action_counts = [0] * action_space.num_actions()
        max_epochs = 100000 # constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        if tune_dataset_size > 0:
            # Test on tuning data
            tmp_agent.test(tune_dataset, vocab, tensorboard=tensorboard,
                           logger=logger, pushover_logger=pushover_logger)

        # Create the learner to compute the loss
        learner = TmpAsynchronousContextualBandit(shared_model, local_model, action_space, meta_data_util,
                                                  config, constants, tensorboard)
        # TODO change 2 --- unity launch moved up
        learner.logger = logger

        for epoch in range(1, max_epochs + 1):

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_model.load_from_state_dict(shared_model.get_state_dict())

                # Progress report every 100 data points; note that the message
                # prints the 0-based index while the check uses index + 1.
                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %(data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" % action_counts)

                # Cap on the number of non-stop actions in this episode.
                num_actions = 0
                max_num_actions = constants["horizon"]
                max_num_actions += constants["max_extra_horizon"]

                image, metadata = tmp_agent.server.reset_receive_feedback(data_point)
                instruction = data_point.get_instruction()
                # instruction_str = TmpAsynchronousContextualBandit.convert_indices_to_text(instruction, vocab)
                # print("Instruction str is ", instruction_str)

                # Pose and Orientation gone TODO change 3
                state = AgentObservedState(instruction=instruction,
                                           config=config,
                                           constants=constants,
                                           start_image=image,
                                           previous_action=None,
                                           data_point=data_point)
                state.goal = learner.get_goal(metadata)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                # Stays True unless the policy itself samples the stop action.
                forced_stop = True

                while num_actions < max_num_actions:

                    # logger.log("Training: Meta Data %r " % metadata)

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, state_feature = \
                        local_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    # A sampled stop ends the rollout; it is sent to the
                    # server by the halt call after the loop.
                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = tmp_agent.server.send_action_receive_feedback(action)
                    # logger.log("Action is %r, Reward is %r Probability is %r " % (action, reward, probabilities))

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state, action, reward, log_prob=log_probabilities)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    # Pose and orientation gone, TODO change 4
                    state = state.update(image, action, data_point=data_point)
                    state.goal = learner.get_goal(metadata)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = tmp_agent.server.halt_and_receive_feedback()
                total_reward += reward

                # Store it in the replay memory list
                # Only a stop chosen by the policy becomes a training example;
                # it reuses the state and log-probabilities of the last
                # get_probs call.
                if not forced_stop:
                    # logger.log("Action is Stop, Reward is %r Probability is %r " % (reward, probabilities))
                    replay_item = ReplayMemoryItem(state, action_space.get_stop_action_index(),
                                                   reward, log_prob=log_probabilities)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                if len(batch_replay_items) > 0:  # 32
                    loss_val = learner.do_update(batch_replay_items)

                    if tensorboard is not None:
                        # cross_entropy = float(learner.cross_entropy.data[0])
                        # tensorboard.log(cross_entropy, loss_val, 0)
                        # NOTE(review): the `.data[0]` reads below presume
                        # each tensor can be indexed at 0 - confirm against
                        # the PyTorch version in use.
                        tensorboard.log_scalar("loss", loss_val)
                        # Entropy is reported divided by num_actions + 1.
                        entropy = float(learner.entropy.data[0])/float(num_actions + 1)
                        tensorboard.log_scalar("entropy", entropy)
                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar("Abs_objective_to_entropy_ratio", ratio)
                        tensorboard.log_scalar("total_reward", total_reward)
                        tensorboard.log_scalar("mean navigation error", metadata['mean-navigation-error'])

                        # Auxiliary losses are None when disabled or not computed.
                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(learner.action_prediction_loss.data[0])
                            learner.tensorboard.log_action_prediction_loss(action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(learner.symbolic_language_prediction_loss.data[0])
                            tensorboard.log_scalar("sym_language_prediction_loss", symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss", goal_prediction_loss)

            # Save the model
            local_model.save_model(experiment + "/contextual_bandit_" + str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)

            if tune_dataset_size > 0:
                # Test on tuning data
                tmp_agent.test(tune_dataset, vocab, tensorboard=tensorboard,
                               logger=logger, pushover_logger=pushover_logger)
예제 #6
0
class TmpSupervisedLearning(AbstractLearning):
    """ Perform Supervised Learning """
    def __init__(self, shared_model, local_model, action_space, meta_data_util,
                 config, constants, tensorboard):
        """Set up the supervised learner, its auxiliary objectives and optimizer.

        Args:
            shared_model: model whose ``get_parameters()`` are handed to the
                Adam optimizer; also forwarded to ``AbstractLearning``.
            local_model: model given to the auxiliary loss calculators and
                forwarded to ``AbstractLearning``.
            action_space: stored on the instance.
            meta_data_util: stored on the instance.
            config: mapping read for the ``do_*`` auxiliary-objective flags.
            constants: mapping read for ``max_epochs``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: logger object stored on the instance and forwarded
                to ``AbstractLearning``.
        """
        self.max_epoch = constants["max_epochs"]
        self.shared_model = shared_model
        self.local_model = local_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy_coef = constants["entropy_coefficient"]

        # Statistics refreshed by calc_loss. They are defined here so that
        # reading them before the first update yields None instead of raising
        # AttributeError.
        self.entropy = None
        self.cross_entropy = None
        self.ratio = None
        self.mean_factor_entropy = None

        # Auxiliary losses always exist as attributes, whether or not the
        # corresponding objective is enabled in the config.
        self.action_prediction_loss = None
        self.temporal_autoencoder_loss = None
        self.object_detection_loss = None
        self.symbolic_language_prediction_loss = None
        self.goal_prediction_loss = None

        # Auxiliary Objectives: a calculator is only built when enabled.
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_model)
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_model)
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(
                self.local_model, num_objects=67)
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_model)
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(self.local_model)

        self.optimizer = optim.Adam(shared_model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_model, self.local_model,
                                  self.calc_loss, self.optimizer, self.config,
                                  self.constants, self.tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the training loss for a batch of replay items.

        The base loss is the negative mean log-probability of the stored
        actions minus ``entropy_coef`` times the policy entropy. The mean
        factor entropy is added when the local model is an
        ``IncrementalModelRecurrentImplicitFactorizationResnet``, and every
        auxiliary objective enabled in the config is added, scaled by its
        coefficient from ``self.constants``.

        Side effects: refreshes ``self.entropy``, ``self.ratio``,
        ``self.mean_factor_entropy`` and the ``*_loss`` attribute of every
        auxiliary objective (None when the objective is disabled or its
        calculator returned None).

        Args:
            batch_replay_items: replay items exposing ``get_action()``,
                ``get_log_prob()`` and ``get_factor_entropy()``.

        Returns:
            The total loss.
        """
        log_probabilities = torch.cat(
            [item.get_log_prob() for item in batch_replay_items])
        action_batch = cuda_var(torch.from_numpy(
            np.array([item.get_action() for item in batch_replay_items])))
        num_states = int(action_batch.size()[0])

        # Log-probability the model assigned to each stored action.
        chosen_log_probs = log_probabilities.gather(
            1, action_batch.view(-1, 1))
        probabilities = torch.exp(log_probabilities)

        self.entropy = -torch.mean(
            torch.sum(log_probabilities * probabilities, 1))
        objective = torch.sum(chosen_log_probs) / num_states
        # Essentially we want the objective and the entropy to increase.
        loss = -objective - self.entropy_coef * self.entropy
        # Diagnostic only: we want this ratio to be high.
        self.ratio = torch.abs(objective) / (self.entropy_coef * self.entropy)

        # Minimize the Factor Entropy if the model is implicit factorization model
        if isinstance(self.local_model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(
                [item.get_factor_entropy() for item in batch_replay_items]))
            loss = loss + self.mean_factor_entropy
        else:
            self.mean_factor_entropy = None

        # (config flag, loss attribute, calculator attribute, coefficient key)
        auxiliary_objectives = (
            ("do_action_prediction", "action_prediction_loss",
             "action_prediction_loss_calculator", "action_prediction_coeff"),
            ("do_temporal_autoencoding", "temporal_autoencoder_loss",
             "temporal_autoencoder_loss_calculator",
             "temporal_autoencoder_coeff"),
            ("do_object_detection", "object_detection_loss",
             "object_detection_loss_calculator", "object_detection_coeff"),
            ("do_symbolic_language_prediction",
             "symbolic_language_prediction_loss",
             "symbolic_language_prediction_loss_calculator",
             "symbolic_language_prediction_coeff"),
            ("do_goal_prediction", "goal_prediction_loss",
             "goal_prediction_calculator", "goal_prediction_coeff"),
        )
        for flag, loss_attr, calculator_attr, coeff_key in auxiliary_objectives:
            aux_loss = None
            if self.config[flag]:
                aux_loss = getattr(self, calculator_attr).calc_loss(
                    batch_replay_items)
                # A calculator may return None; such an objective contributes
                # nothing to the loss for this batch.
                if aux_loss is not None:
                    aux_loss = self.constants[coeff_key] * aux_loss
                    loss = loss + aux_loss
            setattr(self, loss_attr, aux_loss)

        return loss

    @staticmethod
    def convert_text_to_indices(text, vocab, ignore_case=True):
        """Tokenize ``text`` and map every token to its index in ``vocab``.

        Args:
            text: string to tokenize with ``nltk.word_tokenize``.
            vocab: mapping from token to index; must contain ``"$UNK$"`` if
                any token of ``text`` is missing from it.
            ignore_case: when true, tokens are lower-cased before the lookup.

        Returns:
            List of vocabulary indices, one per token.
        """
        words = nltk.word_tokenize(text)
        if ignore_case:
            words = [word.lower() for word in words]

        # The unknown-token entry is only looked up for words missing from
        # the vocabulary.
        return [
            vocab[word] if word in vocab else vocab["$UNK$"]
            for word in words
        ]

    @staticmethod
    def do_train(shared_model,
                 config,
                 action_space,
                 meta_data_util,
                 constants,
                 train_dataset,
                 tune_dataset,
                 experiment,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 model_type,
                 vocab,
                 use_pushover=False):
        """Run ``do_train_`` and print the traceback of any exception it raises.

        All arguments are forwarded unchanged to ``do_train_``. Exceptions
        derived from ``Exception`` are printed and not re-raised.
        """
        forwarded_args = (shared_model, config, action_space, meta_data_util,
                          constants, train_dataset, tune_dataset, experiment,
                          experiment_name, rank, server, logger, model_type,
                          vocab, use_pushover)
        try:
            TmpSupervisedLearning.do_train_(*forwarded_args)
        except Exception:
            # Report the failure on stderr; the caller is not interrupted.
            traceback.print_exc()

    @staticmethod
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  vocab,
                  use_pushover=False):
        """Train on ``train_dataset`` by replaying each stored trajectory.

        For every epoch and data point the local model is synced from the
        shared model, the data point's trajectory (first 300 actions at most)
        is sent to the simulator server followed by a halt, and one update is
        performed on the collected replay items. After each epoch the local
        model is saved and, when ``tune_dataset`` is non-empty, the agent is
        tested on it.

        Args:
            shared_model: provides ``get_state_dict()`` for syncing and is
                handed to the learner.
            config: mapping; ``config["port"]`` selects the simulator build.
            action_space: provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: forwarded to the agent and the learner.
            constants: mapping; ``constants["max_epochs"]`` bounds training.
            train_dataset: sized iterable of data points exposing
                ``get_instruction()`` and ``get_trajectory()``.
            tune_dataset: sized dataset passed to ``tmp_agent.test``.
            experiment: path prefix under which models are saved.
            experiment_name: name given to the ``Tensorboard`` logger.
            rank: client index; only rank 0 creates a ``Tensorboard``.
            server: simulator server used for all environment interaction.
            logger: object exposing ``log(message)``.
            model_type: callable ``model_type(config, constants)`` building
                the local model.
            vocab: forwarded to ``tmp_agent.test``.
            use_pushover: currently has no effect; no pushover logger is
                created either way.
        """

        print("In training...")

        launch_k_unity_builds([config["port"]],
                              "./simulators/house_3_elmer.x86_64")
        server.initialize_server()
        print("launched builds")

        # Test policy
        test_policy = gp.get_argmax_action

        # Client 0 creates a tensorboard server; other clients do not log.
        tensorboard = Tensorboard(experiment_name) if rank == 0 else None

        # The PushoverLogger is disabled, so use_pushover makes no difference.
        pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)

        # Create the Agent
        logger.log("STARTING AGENT")
        tmp_agent = TmpHouseAgent(server=server,
                                  model=local_model,
                                  test_policy=test_policy,
                                  action_space=action_space,
                                  meta_data_util=meta_data_util,
                                  config=config,
                                  constants=constants)
        logger.log("Created Agent...")

        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = TmpSupervisedLearning(shared_model, local_model,
                                        action_space, meta_data_util, config,
                                        constants, tensorboard)

        for epoch in range(1, max_epochs + 1):

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                local_model.load_from_state_dict(shared_model.get_state_dict())

                if (data_point_ix + 1) % 100 == 0:
                    # Report the number of data points completed so far.
                    logger.log("Done %d out of %d" %
                               (data_point_ix + 1, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                image, _ = tmp_agent.server.reset_receive_feedback(data_point)
                instruction = data_point.get_instruction()

                state = AgentObservedState(instruction=instruction,
                                           config=config,
                                           constants=constants,
                                           start_image=image,
                                           previous_action=None,
                                           data_point=data_point)

                model_state = None
                batch_replay_items = []
                total_reward = 0

                # Only the first 300 actions of the stored trajectory are used.
                trajectory = data_point.get_trajectory()[0:300]

                for action in trajectory:

                    # Score the current state; the action itself comes from
                    # the stored trajectory, not from the model.
                    log_probabilities, model_state, _, _ = \
                        local_model.get_probs(state, model_state)

                    action_counts[action] += 1

                    # Send the action and get feedback
                    image, reward, _ = \
                        tmp_agent.server.send_action_receive_feedback(action)

                    # Store it in the replay memory list
                    batch_replay_items.append(
                        ReplayMemoryItem(state,
                                         action,
                                         reward,
                                         log_prob=log_probabilities))

                    # Update the agent state
                    state = state.update(image, action, data_point=data_point)

                    total_reward += reward

                # Send final STOP action and get feedback
                log_probabilities, model_state, _, _ = \
                    local_model.get_probs(state, model_state)
                _, reward, _ = tmp_agent.server.halt_and_receive_feedback()
                total_reward += reward

                # Store the stop step in the replay memory list
                batch_replay_items.append(
                    ReplayMemoryItem(state,
                                     action_space.get_stop_action_index(),
                                     reward,
                                     log_prob=log_probabilities))

                # Perform update
                if len(batch_replay_items) > 0:
                    loss_val = learner.do_update(batch_replay_items)

                    if tensorboard is not None:
                        TmpSupervisedLearning._log_update_stats(
                            tensorboard, learner, loss_val)

            # Save the model
            local_model.save_model(experiment + "/contextual_bandit_" +
                                   str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)

            if tune_dataset_size > 0:
                # Test on tuning data
                print("Going for testing")
                tmp_agent.test(tune_dataset,
                               vocab,
                               tensorboard=tensorboard,
                               logger=logger,
                               pushover_logger=pushover_logger)
                print("Done testing")

    @staticmethod
    def _log_update_stats(tensorboard, learner, loss_val):
        """Log the statistics of the learner's latest update to ``tensorboard``.

        Logs the loss value, the entropy and the objective-to-entropy ratio,
        then every auxiliary loss and the mean factor entropy whose attribute
        on ``learner`` is not None.
        """
        tensorboard.log_scalar("loss_val", loss_val)
        tensorboard.log_scalar("entropy", float(learner.entropy.data[0]))
        tensorboard.log_scalar("Abs_objective_to_entropy_ratio",
                               float(learner.ratio.data[0]))

        if learner.action_prediction_loss is not None:
            tensorboard.log_action_prediction_loss(
                float(learner.action_prediction_loss.data[0]))
        if learner.temporal_autoencoder_loss is not None:
            tensorboard.log_temporal_autoencoder_loss(
                float(learner.temporal_autoencoder_loss.data[0]))
        if learner.object_detection_loss is not None:
            tensorboard.log_object_detection_loss(
                float(learner.object_detection_loss.data[0]))
        if learner.symbolic_language_prediction_loss is not None:
            tensorboard.log_scalar(
                "sym_language_prediction_loss",
                float(learner.symbolic_language_prediction_loss.data[0]))
        if learner.goal_prediction_loss is not None:
            tensorboard.log_scalar(
                "goal_prediction_loss",
                float(learner.goal_prediction_loss.data[0]))
        if learner.mean_factor_entropy is not None:
            tensorboard.log_factor_entropy_loss(
                float(learner.mean_factor_entropy.data[0]))
예제 #7
0
class MultiClientIncrementalDAGGER(AbstractLearning):
    """ Perform DAGGER algorithm of Ross et al.  """
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Set up the DAGGER learner, its auxiliary objectives and optimizer.

        Args:
            model: model whose ``get_parameters()`` are handed to the Adam
                optimizer; also given to the auxiliary loss calculators.
            action_space: stored on the instance.
            meta_data_util: stored on the instance.
            config: mapping read for ``num_client`` and the ``do_*``
                auxiliary-objective flags.
            constants: mapping read for ``max_epochs``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: logger object stored on the instance and forwarded
                to ``AbstractLearning``.
        """
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.num_client = config["num_client"]
        self.tensorboard = tensorboard
        self.entropy_coef = constants["entropy_coefficient"]

        # Statistics refreshed by calc_loss. They are defined here so that
        # reading them before the first update yields None instead of raising
        # AttributeError.
        self.entropy = None
        self.cross_entropy = None
        self.mean_factor_entropy = None

        # Starting DAGGER beta and the base of its per-epoch attenuation.
        self.beta = 0.9
        self.beta_exp_decay = 0.9
        logging.info(
            "DAGGER: using starting beta of %r and beta exp decay of %r",
            self.beta, self.beta_exp_decay)

        # Auxiliary losses always exist as attributes, whether or not the
        # corresponding objective is enabled in the config.
        self.action_prediction_loss = None
        self.temporal_autoencoder_loss = None
        self.object_detection_loss = None
        self.symbolic_language_prediction_loss = None

        # Auxiliary Objectives: a calculator is only built when enabled.
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.model)
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.model)
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(
                self.model, num_objects=67)
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.model)

        self.optimizer = optim.Adam(model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the DAGGER training loss for a batch of replay items.

        The base loss is the negative mean log-probability of the stored
        (expert) actions minus ``entropy_coef`` times the policy entropy. The
        mean factor entropy is added when the model is an
        ``IncrementalModelRecurrentImplicitFactorizationResnet``, and every
        auxiliary objective enabled in the config is added, scaled by its
        coefficient from ``self.constants``.

        Side effects: refreshes ``self.cross_entropy``, ``self.entropy``,
        ``self.mean_factor_entropy`` and the ``*_loss`` attribute of every
        auxiliary objective (None when the objective is disabled or its
        calculator returned None).

        Args:
            batch_replay_items: replay items exposing ``get_action()``,
                ``get_log_prob()`` and ``get_factor_entropy()``.

        Returns:
            The total loss.
        """
        log_probabilities = torch.cat(
            [item.get_log_prob() for item in batch_replay_items])
        # The action stored in each replay item is the expert action.
        action_batch = cuda_var(torch.from_numpy(
            np.array([item.get_action() for item in batch_replay_items])))
        num_states = int(action_batch.size()[0])

        # Log-probability the model assigned to each expert action.
        chosen_log_probs = log_probabilities.gather(
            1, action_batch.view(-1, 1))

        # NOTE(review): hard-coded reference distribution over four actions;
        # presumably an empirical action distribution - confirm its origin.
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        probabilities = torch.exp(log_probabilities)
        mini_batch_action_distribution = torch.mean(probabilities, 0)

        # Cross entropy is recorded for logging; it is not part of the loss.
        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        self.entropy = -torch.mean(
            torch.sum(log_probabilities * probabilities, 1))
        objective = torch.sum(chosen_log_probs) / num_states
        # Essentially we want the objective and the entropy to increase.
        loss = -objective - self.entropy_coef * self.entropy

        # Minimize the Factor Entropy if the model is implicit factorization model
        if isinstance(self.model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(
                [item.get_factor_entropy() for item in batch_replay_items]))
            loss = loss + self.mean_factor_entropy
        else:
            self.mean_factor_entropy = None

        # (config flag, loss attribute, calculator attribute, coefficient key)
        auxiliary_objectives = (
            ("do_action_prediction", "action_prediction_loss",
             "action_prediction_loss_calculator", "action_prediction_coeff"),
            ("do_temporal_autoencoding", "temporal_autoencoder_loss",
             "temporal_autoencoder_loss_calculator",
             "temporal_autoencoder_coeff"),
            ("do_object_detection", "object_detection_loss",
             "object_detection_loss_calculator", "object_detection_coeff"),
            ("do_symbolic_language_prediction",
             "symbolic_language_prediction_loss",
             "symbolic_language_prediction_loss_calculator",
             "symbolic_language_prediction_coeff"),
        )
        for flag, loss_attr, calculator_attr, coeff_key in auxiliary_objectives:
            aux_loss = None
            if self.config[flag]:
                aux_loss = getattr(self, calculator_attr).calc_loss(
                    batch_replay_items)
                # A calculator may return None; such an objective contributes
                # nothing to the loss for this batch.
                if aux_loss is not None:
                    aux_loss = self.constants[coeff_key] * aux_loss
                    loss = loss + aux_loss
            setattr(self, loss_attr, aux_loss)

        return loss

    @staticmethod
    def get_oracle_length(data_point):
        """Return the length of the oracle trajectory for ``data_point``.

        The trajectory is requested from the data point's start position and
        angle to the last entry of its destination list.
        """
        x_pos, z_pos, y_angle = data_point.get_start_pos()
        # Minimal metadata record carrying only the start pose.
        start_metadata = dict(x_pos=x_pos, z_pos=z_pos, y_angle=y_angle)

        final_destination = data_point.get_destination_list()[-1]
        goal_x, goal_z = final_destination

        return len(
            get_oracle_trajectory(start_metadata, goal_x, goal_z, data_point))

    def do_train(self, agent, train_dataset, tune_dataset, experiment_name):
        """ Perform training

        Repeatedly lets every client progress: a client waiting for an
        example receives the next data point, and a client waiting for an
        action receives the model's log-probabilities. An update is performed
        whenever more than 32 replay items have accumulated. When all clients
        are waiting for an example the epoch is over: the dataset iterator is
        reset, the DAGGER beta is attenuated, the model is saved, and (unless
        it was the last epoch) the agent is tested on ``tune_dataset``.

        Args:
            agent: handed to every ``Client`` and used for testing.
            train_dataset: dataset wrapped in a ``DatasetIterator``.
            tune_dataset: dataset passed to ``agent.test`` between epochs.
            experiment_name: path prefix under which models are saved.
        """

        # This one list is passed to every client (presumably each appends
        # its replay items to it - confirm in Client); it is cleared in place
        # after each update so the clients keep seeing the same object.
        batch_replay_items = []
        clients = [
            Client(agent, self.config, self.constants, self.action_space,
                   self.tensorboard, client_ix, batch_replay_items, self.beta)
            for client_ix in range(self.num_client)
        ]

        dataset_iterator = DatasetIterator(train_dataset)
        epoch = 1

        if epoch <= self.max_epoch:
            logging.info("Starting epoch %d", epoch)

        while True:

            for client in clients:

                # See if the client can progress
                client_status = client.try_to_progress()
                if client_status == Client.WAITING_FOR_EXAMPLE:
                    # Provide the next example
                    data_point = dataset_iterator.get_next()
                    if data_point is None:
                        continue
                    # Horizon: oracle length plus a fixed slack.
                    max_num_actions = self.get_oracle_length(data_point)
                    max_num_actions += self.constants["max_extra_horizon"]

                    if self.tensorboard is not None:
                        self.tensorboard.log_scalar("total_reward",
                                                    client.total_reward)
                    client.accept_new_example(data_point, max_num_actions)

                elif client_status == Client.WAITING_FOR_ACTION:

                    # Generate probabilities over actions and take action
                    log_probabilities, new_model_state, image_emb_seq = self.model.get_probs(
                        client.get_state(), client.get_model_state())
                    if isinstance(
                            self.model,
                            IncrementalModelRecurrentImplicitFactorizationResnet
                    ):
                        factor_entropy = self.model.get_recent_factorization_entropy(
                        )
                    else:
                        factor_entropy = None
                    client.take_action(log_probabilities, new_model_state,
                                       image_emb_seq, factor_entropy)

                elif client_status == Client.WAITING_TO_RECEIVE:
                    pass
                else:
                    raise AssertionError("Unknown status. Found " +
                                         str(client_status))

            # Perform update
            if len(batch_replay_items) > 32:
                loss_val = self.do_update(batch_replay_items)
                del batch_replay_items[:]  # in place list clear
                self._log_update_stats(loss_val)

            # Check if an epoch is finished. An epoch is over if all clients are waiting
            # for an example (at which point the iterator also returns none)
            epoch_completed = all(
                client.get_status() == Client.WAITING_FOR_EXAMPLE
                for client in clients)
            if epoch_completed:
                assert dataset_iterator.get_next() is None

                # Reset the iterator
                dataset_iterator.reset()

                # Attenuate the dagger beta value
                self.beta = math.pow(self.beta_exp_decay, epoch)
                for client in clients:
                    client.update_dagger_beta(self.beta)
                logging.info("Attenuated the beta to %r after epoch %r",
                             self.beta, epoch)

                # Save the model
                self.model.save_model(experiment_name + "/dagger_epoch_" +
                                      str(epoch))
                if epoch >= self.max_epoch:
                    break
                epoch += 1
                logging.info("Starting epoch %d", epoch)

                # Test on tuning data
                agent.test(tune_dataset, tensorboard=self.tensorboard)

    def _log_update_stats(self, loss_val):
        """Log the statistics of the latest update, if a tensorboard is set.

        Logs the cross entropy with the loss value and the entropy, then
        every auxiliary loss and the mean factor entropy that is not None.
        """
        if self.tensorboard is None:
            # Nothing to log to.
            return

        cross_entropy = float(self.cross_entropy.data[0])
        self.tensorboard.log(cross_entropy, loss_val, 0)
        self.tensorboard.log_scalar("entropy", float(self.entropy.data[0]))

        if self.action_prediction_loss is not None:
            self.tensorboard.log_action_prediction_loss(
                float(self.action_prediction_loss.data[0]))
        if self.temporal_autoencoder_loss is not None:
            self.tensorboard.log_temporal_autoencoder_loss(
                float(self.temporal_autoencoder_loss.data[0]))
        if self.object_detection_loss is not None:
            self.tensorboard.log_object_detection_loss(
                float(self.object_detection_loss.data[0]))
        if self.symbolic_language_prediction_loss is not None:
            self.tensorboard.log_scalar(
                "sym_language_prediction_loss",
                float(self.symbolic_language_prediction_loss.data[0]))
        if self.mean_factor_entropy is not None:
            self.tensorboard.log_factor_entropy_loss(
                float(self.mean_factor_entropy.data[0]))
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Create the learner: vocabulary, auxiliary objectives and optimizer.

        Args:
            model: model whose ``get_parameters()`` are handed to the Adam
                optimizer; also given to the auxiliary loss calculators.
            action_space: stored on the instance.
            meta_data_util: stored on the instance.
            config: mapping read for ``vocab_file`` and the ``do_*``
                auxiliary-objective flags.
            constants: mapping read for ``max_epochs``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: logger object stored on the instance and forwarded
                to ``AbstractLearning``.
        """
        # Collaborators and hyper-parameters.
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        # Height and width handed to the pixel-level auxiliary calculators.
        self.final_height = 32
        self.final_width = 32

        self.ignore_none = True
        self.only_first = True

        # The vocabulary file holds one token per line; a token's index is
        # its zero-based line number (a repeated token keeps the last one).
        with open(config["vocab_file"]) as vocab_file:
            self.vocab = {
                row.strip(): row_ix
                for row_ix, row in enumerate(vocab_file)
            }

        # Auxiliary Objectives
        aux_flags = self.config
        if aux_flags["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.model)
            self.action_prediction_loss = None
        if aux_flags["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.model)
            self.temporal_autoencoder_loss = None
        if aux_flags["do_object_detection"]:
            pixel_identifier = ObjectPixelIdentification(
                self.model,
                num_objects=67,
                camera_angle=60,
                image_height=self.final_height,
                image_width=self.final_width,
                object_height=0)  # -2.5)
            self.object_detection_loss_calculator = pixel_identifier
            self.object_detection_loss = None
        if aux_flags["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.model)
            self.symbolic_language_prediction_loss = None
        if aux_flags["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.model, self.final_height, self.final_width)
            self.goal_prediction_loss = None

        self.cross_entropy_loss = None
        self.dist_loss = None

        learning_rate = constants["learning_rate"]
        self.optimizer = optim.Adam(model.get_parameters(), lr=learning_rate)
        AbstractLearning.__init__(self, self.model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

        logging.info(
            "Created Goal predictor with ignore_none %r and only_first %r",
            self.ignore_none, self.only_first)
예제 #9
0
class AsynchronousAdvantageActorGAECritic(AbstractLearning):
    """ Perform Asynchronous Advantage Actor Critic with Generalized Advantage Estimate """
    def __init__(self, shared_model, local_model, action_space, meta_data_util,
                 config, constants, tensorboard):
        """Wire up models, hyper-parameters, auxiliary losses and the optimizer.

        Args:
            shared_model: model whose parameters the Adam optimizer updates;
                its image module also supplies the final image dimensions.
            local_model: model handed to every auxiliary loss calculator.
            action_space: stored unchanged on the learner.
            meta_data_util: stored unchanged on the learner.
            config: flag mapping; the ``do_*`` keys decide which auxiliary
                objectives get a loss calculator.
            constants: mapping providing ``max_epochs``, ``gamma``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: logging handle forwarded to ``AbstractLearning``.
        """
        # Hyper-parameters.
        self.max_epoch = constants["max_epochs"]
        self.gamma = constants["gamma"]
        self.tau = 1.0  # multiplies gamma when advantages are accumulated

        # Collaborators supplied by the caller.
        self.shared_model, self.local_model = shared_model, local_model
        self.action_space, self.meta_data_util = action_space, meta_data_util
        self.config, self.constants = config, constants
        self.tensorboard = tensorboard

        # Diagnostics that calc_loss fills in on every call.
        self.entropy = self.cross_entropy = None
        self.ratio = self.value_loss = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        final_dims = shared_model.image_module.get_final_dimension()
        self.image_channels, self.image_height, self.image_width = final_dims

        # Auxiliary objectives: one calculator (plus a loss slot) per enabled flag.
        flags = self.config
        if flags["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(local_model)
            self.action_prediction_loss = None
        if flags["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(local_model)
            self.temporal_autoencoder_loss = None
        if flags["do_object_detection"]:
            detector_kwargs = dict(num_objects=67,
                                   camera_angle=60,
                                   image_height=self.image_height,
                                   image_width=self.image_width,
                                   object_height=0)  # -2.5)
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                local_model, **detector_kwargs)
            self.object_detection_loss = None
        if flags["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = \
                SymbolicLanguagePrediction(local_model)
            self.symbolic_language_prediction_loss = None
        if flags["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                local_model, self.image_height, self.image_width)
            self.goal_prediction_loss = None

        # The optimizer steps the *shared* parameters.
        shared_parameters = shared_model.get_parameters()
        self.optimizer = optim.Adam(shared_parameters,
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, shared_model, local_model,
                                  self.calc_loss, self.optimizer, config,
                                  constants, tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the actor-critic loss using generalized advantage estimates.

        Assumes ``batch_replay_items`` holds the replay items of one rollout in
        temporal order and is non-empty (``torch.cat`` is applied to lists
        built from it).

        Side effects: stores ``value_loss``, ``cross_entropy``, ``entropy`` and
        ``ratio`` on ``self``, and sets every auxiliary ``*_loss`` attribute
        (None when the objective is disabled or its calculator returns None).

        Args:
            batch_replay_items: sequence of replay items exposing
                ``get_action()``, ``get_log_prob()``, ``get_reward()`` and
                ``get_volatile_features()["state_value"]``.

        Returns:
            The loss: negated policy objective, minus the epoch-decayed entropy
            bonus, plus 0.25 * value loss, plus any enabled auxiliary losses.
        """
        action_batch = []
        log_probabilities = []
        v_values = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())
            log_probabilities.append(replay_item.get_log_prob())
            v_values.append(replay_item.get_volatile_features()["state_value"])

        # Walk the rollout backwards, accumulating the generalized advantage
        # and the undiscounted reward-to-go for every step.
        # NOTE(review): the advantages are not detached here; whether gradients
        # reach the value estimate through the policy term depends on how
        # "state_value" was produced -- confirm this is intended.
        generalized_advantages = []
        total_reward = []
        last_v_value = None
        sum_reward = 0
        for replay_item in reversed(batch_replay_items):
            v_value = replay_item.get_volatile_features()["state_value"]
            reward = replay_item.get_reward()

            if last_v_value is None:
                # Final step of the rollout: nothing to bootstrap from.
                q_val = reward
                generalized_advantage = q_val - v_value
            else:
                q_val = reward + self.gamma * last_v_value
                advantage = q_val - v_value
                generalized_advantage = \
                    self.tau * self.gamma * generalized_advantage + advantage

            sum_reward += reward
            last_v_value = v_value
            generalized_advantages.append(generalized_advantage)
            total_reward.append(sum_reward)

        # Reverse the advantages and total reward to temporal order
        generalized_advantages.reverse()
        total_reward.reverse()

        log_probabilities = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))
        generalized_advantages = torch.cat(generalized_advantages).view(-1)

        total_reward = cuda_var(
            torch.from_numpy(np.array(total_reward)).float()).view(-1)
        v_values = torch.cat(v_values).view(-1)

        # Log-probability of the action actually taken at every step.
        chosen_log_probs = log_probabilities.gather(
            1, action_batch.view(-1, 1))
        advantage_log_prob = generalized_advantages * chosen_log_probs.view(-1)

        # Hard-coded reference distribution over four actions; the resulting
        # cross entropy is only stored for logging, it is not part of the loss.
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(log_probabilities)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        self.value_loss = torch.sum((v_values - total_reward)**2)

        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        # Entropy is summed (not averaged) over the steps of the rollout.
        self.entropy = -torch.sum(
            torch.sum(log_probabilities * model_prob_batch, 1))
        objective = torch.sum(advantage_log_prob)  # / num_states
        # The entropy bonus decays linearly with the epoch and is floored at 0.
        entropy_coef = max(0, self.entropy_coef - self.epoch * 0.01)
        loss = -objective - entropy_coef * self.entropy + 0.25 * self.value_loss
        # Diagnostic only; we want the ratio to be high. Once entropy_coef has
        # decayed to 0 the denominator is 0 as well.
        self.ratio = torch.abs(objective) / (entropy_coef * self.entropy)

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            # Guard against a missing loss exactly like the other auxiliary
            # objectives do, instead of failing on ``coeff * None``.
            if self.symbolic_language_prediction_loss is not None:
                self.symbolic_language_prediction_loss = \
                    self.constants["symbolic_language_prediction_coeff"] * \
                    self.symbolic_language_prediction_loss
                loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            self.goal_prediction_loss, _, _ = self.goal_prediction_calculator.calc_loss(
                batch_replay_items)
            if self.goal_prediction_loss is not None:
                self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * \
                                            self.goal_prediction_loss
                loss = loss + self.goal_prediction_loss  # * len(batch_replay_items)  # scale the loss
        else:
            self.goal_prediction_loss = None

        return loss

    @staticmethod
    def do_train(shared_model, config, action_space, meta_data_util, constants,
                 train_dataset, tune_dataset, experiment, experiment_name, rank,
                 server, logger, model_type, use_pushover=False):
        """Run ``do_train_`` and print, rather than propagate, any exception.

        All arguments are forwarded positionally and unchanged to
        ``AsynchronousAdvantageActorGAECritic.do_train_``. Anything derived
        from ``Exception`` raised there is caught and its traceback printed,
        after which this function returns normally.
        """
        forwarded = (shared_model, config, action_space, meta_data_util,
                     constants, train_dataset, tune_dataset, experiment,
                     experiment_name, rank, server, logger, model_type,
                     use_pushover)
        try:
            AsynchronousAdvantageActorGAECritic.do_train_(*forwarded)
        except Exception:
            # Report the failure without re-raising it.
            traceback.print_exc()

    @staticmethod
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):
        """Train one worker with rollouts sampled from its local model.

        For every epoch and every training data point the local model is
        synced from the shared model, actions are sampled until the stop action
        is drawn or ``horizon + max_extra_horizon`` actions were taken, the
        server is halted, and ``learner.do_update`` is called on the collected
        replay items. After each epoch the local model is saved, training
        statistics are logged and, if ``tune_dataset`` is non-empty, the agent
        is tested on it.

        Args:
            shared_model: model the local model is synced from before every
                data point; also handed to the learner.
            config: mapping; ``port`` and ``do_goal_prediction`` are read here.
            action_space: provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: forwarded to the agent and the learner.
            constants: mapping; ``max_epochs``, ``horizon`` and
                ``max_extra_horizon`` are read here.
            train_dataset: sized iterable of training data points.
            tune_dataset: sized collection passed to ``agent.test``.
            experiment: path prefix under which models are saved.
            experiment_name: name given to the tensorboard / pushover loggers.
            rank: worker index; only rank 0 creates a tensorboard logger.
            server: simulator server used for resets, actions and halting.
            logger: object with a ``log(str)`` method.
            model_type: callable ``model_type(config, constants)`` building
                the local model.
            use_pushover: when true, a ``PushoverLogger`` is created and passed
                to ``agent.test``.
        """

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)
        # local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        # Running count of sampled actions, accumulated across all epochs.
        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = AsynchronousAdvantageActorGAECritic(shared_model,
                                                      local_model,
                                                      action_space,
                                                      meta_data_util, config,
                                                      constants, tensorboard)

        # Launch unity
        launch_k_unity_builds([config["port"]],
                              "./simulators/NavDroneLinuxBuild.x86_64")

        for epoch in range(1, max_epochs + 1):

            # calc_loss uses the epoch to decay its entropy coefficient.
            learner.epoch = epoch
            task_completion_accuracy = 0
            mean_stop_dist_error = 0
            stop_dist_errors = []
            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_model.load_from_state_dict(shared_model.get_state_dict())

                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                num_actions = 0
                max_num_actions = constants["horizon"] + constants[
                    "max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                pose = int(metadata["y_angle"] / 15.0)
                position_orientation = (metadata["x_pos"], metadata["z_pos"],
                                        metadata["y_angle"])
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    pose=pose,
                    position_orientation=position_orientation,
                    data_point=data_point)
                state.goal = GoalPrediction.get_goal_location(
                    metadata, data_point, learner.image_height,
                    learner.image_width)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                # Stays True unless the policy itself samples the stop action.
                forced_stop = True

                while num_actions < max_num_actions:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, volatile = \
                        local_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    # Generate goal
                    if config["do_goal_prediction"]:
                        goal = learner.goal_prediction_calculator.get_goal_location(
                            metadata, data_point, learner.image_height,
                            learner.image_width)
                    else:
                        goal = None

                    # The stop action is not sent here; the halt call after the
                    # loop delivers it.
                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   volatile=volatile,
                                                   goal=goal)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    pose = int(metadata["y_angle"] / 15.0)
                    position_orientation = (metadata["x_pos"],
                                            metadata["z_pos"],
                                            metadata["y_angle"])
                    state = state.update(
                        image,
                        action,
                        pose=pose,
                        position_orientation=position_orientation,
                        data_point=data_point)
                    state.goal = GoalPrediction.get_goal_location(
                        metadata, data_point, learner.image_height,
                        learner.image_width)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                total_reward += reward

                # A stop within a distance error of 5.0 counts as a completed task.
                if metadata["stop_dist_error"] < 5.0:
                    task_completion_accuracy += 1
                mean_stop_dist_error += metadata["stop_dist_error"]
                stop_dist_errors.append(metadata["stop_dist_error"])

                if tensorboard is not None:
                    tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

                # Store the stop step in the replay memory list, but only when
                # the policy chose to stop; it reuses the log-probabilities,
                # volatile features and goal from the last policy query.
                if not forced_stop:
                    replay_item = ReplayMemoryItem(
                        state,
                        action_space.get_stop_action_index(),
                        reward,
                        log_prob=log_probabilities,
                        volatile=volatile,
                        goal=goal)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                if len(batch_replay_items) > 0:  # 32:
                    loss_val = learner.do_update(batch_replay_items)
                    # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                    # del batch_replay_items[:]  # in place list clear

                    # NOTE(review): the `.data[0]` reads below presume the
                    # stored losses can be indexed at position 0 -- confirm
                    # against the PyTorch version in use.
                    if tensorboard is not None:
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        # Entropy and value loss are logged per step, dividing
                        # by the number of actions plus one.
                        entropy = float(
                            learner.entropy.data[0]) / float(num_actions + 1)
                        v_value_loss_per_step = float(
                            learner.value_loss.data[0]) / float(num_actions +
                                                                1)
                        tensorboard.log_scalar("entropy", entropy)
                        tensorboard.log_scalar("total_reward", total_reward)
                        tensorboard.log_scalar("v_value_loss_per_step",
                                               v_value_loss_per_step)
                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        # Auxiliary losses are logged only when calc_loss
                        # produced them.
                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)

            # Save the model (one file per worker rank and epoch)
            local_model.save_model(experiment + "/contextual_bandit_" +
                                   str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)
            # Turn the per-epoch sums into a mean error and an accuracy percentage.
            mean_stop_dist_error = mean_stop_dist_error / float(
                len(train_dataset))
            task_completion_accuracy = (task_completion_accuracy *
                                        100.0) / float(len(train_dataset))
            logger.log("Training: Mean stop distance error %r" %
                       mean_stop_dist_error)
            logger.log("Training: Task completion accuracy %r " %
                       task_completion_accuracy)
            bins = range(0, 80, 3)  # range of distance
            histogram, _ = np.histogram(stop_dist_errors, bins)
            logger.log("Histogram of train errors %r " % histogram)

            if tune_dataset_size > 0:
                # Test on tuning data
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)
예제 #10
0
class AsynchronousSupervisedLearning(AbstractLearning):
    """ Perform supervised learning """
    def __init__(self, shared_model, local_model, action_space, meta_data_util,
                 config, constants, tensorboard):
        """Store collaborators, build auxiliary loss calculators and the optimizer.

        Args:
            shared_model: model whose parameters the Adam optimizer updates.
            local_model: model handed to every auxiliary loss calculator.
            action_space: stored unchanged on the learner.
            meta_data_util: stored unchanged on the learner.
            config: flag mapping; the ``do_*`` keys decide which auxiliary
                objectives get a loss calculator.
            constants: mapping providing ``max_epochs``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: logging handle forwarded to ``AbstractLearning``.
        """
        self.max_epoch = constants["max_epochs"]

        # Collaborators supplied by the caller.
        self.shared_model, self.local_model = shared_model, local_model
        self.action_space, self.meta_data_util = action_space, meta_data_util
        self.config, self.constants = config, constants
        self.tensorboard = tensorboard

        # Diagnostics that calc_loss fills in.
        self.entropy = self.cross_entropy = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        # Auxiliary objectives: one calculator (plus a loss slot) per enabled flag.
        flags = self.config
        if flags["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(local_model)
            self.action_prediction_loss = None
        if flags["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(local_model)
            self.temporal_autoencoder_loss = None
        if flags["do_object_detection"]:
            # Image size is fixed at 8x8 for this learner.
            detector_kwargs = dict(num_objects=67,
                                   camera_angle=60,
                                   image_height=8,
                                   image_width=8,
                                   object_height=0)  # -2.5)
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                local_model, **detector_kwargs)
            self.object_detection_loss = None
        if flags["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = \
                SymbolicLanguagePrediction(local_model)
            self.symbolic_language_prediction_loss = None
        if flags["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(local_model)
            self.goal_prediction_loss = None

        # The optimizer steps the *shared* parameters.
        shared_parameters = shared_model.get_parameters()
        self.optimizer = optim.Adam(shared_parameters,
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, shared_model, local_model,
                                  self.calc_loss, self.optimizer, config,
                                  constants, tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the supervised (log-likelihood) loss for a batch of replay items.

        Side effects: stores ``cross_entropy``, ``entropy``, ``ratio`` and
        ``mean_factor_entropy`` on ``self``, and sets every auxiliary
        ``*_loss`` attribute (None when the objective is disabled or its
        calculator returns None). ``goal_prob`` is set only when goal
        prediction is enabled.

        Args:
            batch_replay_items: non-empty sequence of replay items exposing
                ``get_action()``, ``get_log_prob()`` and
                ``get_factor_entropy()``.

        Returns:
            The loss. When ``do_goal_prediction`` is enabled everything
            computed before it is discarded: the weighted goal prediction loss
            is returned on its own, or None if the calculator yields no loss.
        """
        action_batch = []
        log_probabilities = []
        factor_entropy = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())

        log_probabilities = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))

        num_states = int(action_batch.size()[0])
        # Log-probability the model assigns to each recorded action.
        chosen_log_probs = log_probabilities.gather(
            1, action_batch.view(-1, 1))

        # Hard-coded reference distribution over four actions; the resulting
        # cross entropy is only stored, it is not part of the loss.
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(log_probabilities)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        self.entropy = -torch.mean(
            torch.sum(log_probabilities * model_prob_batch, 1))
        objective = torch.sum(chosen_log_probs) / num_states
        # Essentially we want the objective to increase and cross entropy to decrease
        loss = -objective - self.entropy_coef * self.entropy
        # Diagnostic only; we want the ratio to be high.
        self.ratio = torch.abs(objective) / (self.entropy_coef * self.entropy)

        # Minimize the Factor Entropy if the model is implicit factorization model
        if isinstance(self.local_model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(factor_entropy))
            loss = loss + self.mean_factor_entropy
        else:
            self.mean_factor_entropy = None

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            # Guard against a missing loss exactly like the other auxiliary
            # objectives do, instead of failing on ``coeff * None``.
            if self.symbolic_language_prediction_loss is not None:
                self.symbolic_language_prediction_loss = \
                    self.constants["symbolic_language_prediction_coeff"] * \
                    self.symbolic_language_prediction_loss
                loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            self.goal_prediction_loss, self.goal_prob = self.goal_prediction_calculator.calc_loss(
                batch_replay_items)
            if self.goal_prediction_loss is not None:
                self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * \
                                            self.goal_prediction_loss
                # The goal prediction loss replaces (rather than augments) the
                # loss accumulated so far.
                loss = self.goal_prediction_loss
            else:
                loss = None
        else:
            self.goal_prediction_loss = None

        return loss

    @staticmethod
    def save_goal(batch_replay_items, data_point_ix, trajectory):
        """Write the goal of every replay item to the example's ``goal.txt``.

        The file is
        ``../logs/oracle_images/tune_images/example_<data_point_ix>/goal.txt``
        (relative to the current working directory); it receives one line per
        replay item holding the four ``item.goal`` values formatted with ``%r``.

        Args:
            batch_replay_items: items whose ``goal`` attribute unpacks into
                ``(row, col, row_real, col_real)``.
            data_point_ix: index used to build the example directory name.
            trajectory: sized sequence; there must be exactly one more replay
                item than trajectory entries.

        Raises:
            AssertionError: if the length relation above does not hold.
        """
        assert len(batch_replay_items) == len(trajectory) + 1
        goal_path = ("../logs/oracle_images/tune_images/example_" +
                     str(data_point_ix) + "/goal.txt")
        # The context manager flushes and closes the file even when unpacking
        # a goal or writing a line raises.
        with open(goal_path, "w") as f:
            for item in batch_replay_items:
                row, col, row_real, col_real = item.goal
                f.write("%r %r %r %r\n" % (row, col, row_real, col_real))

    @staticmethod
    def do_train(shared_model, config, action_space, meta_data_util, constants,
                 train_dataset, tune_dataset, experiment, experiment_name, rank,
                 server, logger, model_type, use_pushover=False):
        """Run ``do_train_`` and print, rather than propagate, any exception.

        All arguments are forwarded positionally and unchanged to
        ``AsynchronousSupervisedLearning.do_train_``. Anything derived from
        ``Exception`` raised there is caught and its traceback printed, after
        which this function returns normally.
        """
        forwarded = (shared_model, config, action_space, meta_data_util,
                     constants, train_dataset, tune_dataset, experiment,
                     experiment_name, rank, server, logger, model_type,
                     use_pushover)
        try:
            AsynchronousSupervisedLearning.do_train_(*forwarded)
        except Exception:
            # Report the failure without re-raising it.
            traceback.print_exc()

    @staticmethod
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):
        """Run the supervised training loop for a single client.

        For every epoch and every training data point, the local model is
        re-synced from ``shared_model``, the data point's stored trajectory is
        replayed action by action through ``server``, one ``ReplayMemoryItem``
        is recorded per step (plus a final one for the stop action), and the
        collected items are handed to ``learner.do_update``. After each epoch
        the local model is saved and, when ``tune_dataset`` is non-empty, the
        agent is evaluated on it with ``agent.test_goal_prediction``.

        Args:
            shared_model: model whose state dict is copied into the local
                model before each data point; also passed to the learner.
            config: configuration mapping; read here for
                ``"do_goal_prediction"`` and ``"port"``.
            action_space: provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: forwarded to the agent and the learner.
            constants: constants mapping; read here for ``"max_epochs"``.
            train_dataset: sized iterable of data points exposing
                ``get_trajectory()`` and ``instruction``.
            tune_dataset: sized collection used for evaluation after each
                epoch; skipped when empty.
            experiment: path prefix under which per-epoch models are saved.
            experiment_name: name given to ``Tensorboard`` and
                ``PushoverLogger``.
            rank: client index; only rank 0 creates a ``Tensorboard`` and the
                rank is embedded in the saved model file name.
            server: simulator server interface used for all rollouts.
            logger: object exposing ``log(message)``.
            model_type: callable invoked as ``model_type(config, constants)``
                to build the local model.
            use_pushover: when true, a ``PushoverLogger`` is created and
                passed to the evaluation call.
        """

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        # Running count of how often each action index appears in the replayed
        # trajectories (the final stop action is not counted below).
        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = AsynchronousSupervisedLearning(shared_model, local_model,
                                                 action_space, meta_data_util,
                                                 config, constants,
                                                 tensorboard)

        # Launch unity
        launch_k_unity_builds([config["port"]],
                              "./simulators/NavDroneLinuxBuild.x86_64")

        for epoch in range(1, max_epochs + 1):

            learner.epoch = epoch

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_model.load_from_state_dict(shared_model.get_state_dict())

                # NOTE(review): the condition fires on every 100th item but the
                # message prints the zero-based index (99, 199, ...).
                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                num_actions = 0
                trajectory = data_point.get_trajectory()
                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                # Pose is the heading angle bucketed into 15-unit bins
                # (presumably degrees -- TODO confirm).
                pose = int(metadata["y_angle"] / 15.0)
                position_orientation = (metadata["x_pos"], metadata["z_pos"],
                                        metadata["y_angle"])
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    pose=pose,
                    position_orientation=position_orientation,
                    data_point=data_point)

                model_state = None
                batch_replay_items = []
                total_reward = 0

                # Replay the stored trajectory: the action comes from the data
                # point, the model is only queried for its log-probabilities.
                for action in trajectory:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, volatile = \
                        local_model.get_probs(state, model_state)

                    action_counts[action] += 1

                    # Generate goal
                    if config["do_goal_prediction"]:
                        goal = learner.goal_prediction_calculator.get_goal_location(
                            metadata, data_point, 8, 8)
                        # learner.goal_prediction_calculator.save_attention_prob(image, volatile)
                        # time.sleep(5)
                    else:
                        goal = None

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   volatile=volatile,
                                                   goal=goal)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    pose = int(metadata["y_angle"] / 15.0)
                    position_orientation = (metadata["x_pos"],
                                            metadata["z_pos"],
                                            metadata["y_angle"])
                    state = state.update(
                        image,
                        action,
                        pose=pose,
                        position_orientation=position_orientation,
                        data_point=data_point)

                    num_actions += 1
                    total_reward += reward

                # One more forward pass on the final state so the stop action
                # gets its own replay item.
                log_probabilities, model_state, image_emb_seq, volatile = \
                    local_model.get_probs(state, model_state)

                # Generate goal
                if config["do_goal_prediction"]:
                    goal = learner.goal_prediction_calculator.get_goal_location(
                        metadata, data_point, 8, 8)
                    # learner.goal_prediction_calculator.save_attention_prob(image, volatile)
                    # time.sleep(5)
                else:
                    goal = None

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                total_reward += reward

                if tensorboard is not None:
                    tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

                # Store it in the replay memory list
                replay_item = ReplayMemoryItem(
                    state,
                    action_space.get_stop_action_index(),
                    reward,
                    log_prob=log_probabilities,
                    volatile=volatile,
                    goal=goal)
                batch_replay_items.append(replay_item)

                ###########################################3
                # NOTE(review): save_goal runs for every data point; it looks
                # like a debugging/visualisation hook -- confirm it is intended
                # to stay enabled during training.
                AsynchronousSupervisedLearning.save_goal(
                    batch_replay_items, data_point_ix, trajectory)
                ###########################################3

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                # The list always holds at least the stop item at this point,
                # so this guard is always true.
                if len(batch_replay_items) > 0:  # 32:
                    loss_val = learner.do_update(batch_replay_items)
                    # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                    # del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        # The `.data[0]` reads below assume each statistic is a
                        # one-element tensor (pre-0.4 PyTorch style) -- TODO
                        # confirm against the installed torch version.
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        # Entropy is divided by the number of replay items
                        # (trajectory actions plus the stop item).
                        entropy = float(
                            learner.entropy.data[0]) / float(num_actions + 1)
                        tensorboard.log_scalar("entropy", entropy)
                        tensorboard.log_scalar("total_reward", total_reward)

                        # NOTE(review): `ratio`, `goal_prob` and
                        # `mean_factor_entropy` are presumably populated by the
                        # learner during do_update -- not visible here, verify.
                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            # Goes through learner.tensorboard, which was
                            # constructed from the same `tensorboard` object.
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)
                        if learner.goal_prob is not None:
                            goal_prob = float(learner.goal_prob.data[0])
                            tensorboard.log_scalar("goal_prob", goal_prob)
                        if learner.mean_factor_entropy is not None:
                            mean_factor_entropy = float(
                                learner.mean_factor_entropy.data[0])
                            tensorboard.log_factor_entropy_loss(
                                mean_factor_entropy)

            # Save the model
            # The *local* model is saved; it was last synced from the shared
            # model before the final data point of the epoch.
            local_model.save_model(experiment + "/supervised_learning_" +
                                   str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)

            if tune_dataset_size > 0:
                # Test on tuning data
                agent.test_goal_prediction(tune_dataset,
                                           tensorboard=tensorboard,
                                           logger=logger,
                                           pushover_logger=pushover_logger)
class AsynchronousTwoStageContextualBandit(AbstractLearning):
    """ Perform Contextual Bandit learning (Kakade and Langford (circa 2006) & Misra, Langford and Artzi EMNLP 2017)
    on the two stage model. """
    def __init__(self, shared_navigator_model, local_navigator_model,
                 shared_predictor_model, local_predictor_model, action_space,
                 meta_data_util, config, constants, tensorboard):
        """Set up the learner for the two-stage (predictor + navigator) model.

        Args:
            shared_navigator_model: navigator model whose parameters are
                optimized; also supplies the final image dimensions.
            local_navigator_model: navigator model handed to the auxiliary
                loss calculators and to ``AbstractLearning``.
            shared_predictor_model: predictor model whose parameters are
                optimized together with the navigator's.
            local_predictor_model: predictor model used by ``_sample_goal``.
            action_space: action space object, stored as is.
            meta_data_util: metadata helper, stored as is.
            config: configuration mapping; the ``do_*`` flags select which
                auxiliary loss calculators are created.
            constants: constants mapping; read for ``"max_epochs"``,
                ``"entropy_coefficient"`` and ``"learning_rate"``.
            tensorboard: tensorboard logger or ``None``, stored as is.
        """
        self.max_epoch = constants["max_epochs"]
        self.shared_navigator_model = shared_navigator_model
        self.local_navigator_model = local_navigator_model
        self.shared_predictor_model = shared_predictor_model
        self.local_predictor_model = local_predictor_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.ratio = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        self.image_channels, self.image_height, self.image_width = \
            shared_navigator_model.image_module.get_final_dimension()

        # Auxiliary loss values always exist (as None) so that code reading
        # them never hits AttributeError, whichever objectives are enabled.
        self.action_prediction_loss = None
        self.temporal_autoencoder_loss = None
        self.object_detection_loss = None
        self.symbolic_language_prediction_loss = None
        self.goal_prediction_loss = None

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_navigator_model)
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_navigator_model)
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.local_navigator_model,
                num_objects=67,
                camera_angle=60,
                image_height=self.image_height,
                image_width=self.image_width,
                object_height=0)  # -2.5)
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_navigator_model)
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.local_navigator_model, self.image_height,
                self.image_width)

        # Build a fresh list rather than extending whatever get_parameters()
        # returned: extending in place would silently grow the navigator's own
        # list if the model hands out a cached one.
        parameters = list(self.shared_navigator_model.get_parameters())
        parameters.extend(self.shared_predictor_model.get_parameters())

        self.optimizer = optim.Adam(parameters, lr=constants["learning_rate"])
        # NOTE(review): only the navigator models are passed to the base class
        # although the optimizer also covers the predictor's parameters --
        # confirm the base class handles the predictor as intended.
        AbstractLearning.__init__(self, self.shared_navigator_model,
                                  self.local_navigator_model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the contextual-bandit loss for one batch of replay items.

        The objective is the sum over items of
        ``reward * (log p(action) + log p(sampled goal))``. The returned loss
        is ``-objective - entropy_coef * entropy`` plus every enabled
        auxiliary loss scaled by its coefficient from ``self.constants``.

        Side effects: updates ``self.cross_entropy``, ``self.entropy``,
        ``self.ratio`` and the five ``self.*_loss`` auxiliary attributes
        (``None`` for objectives that are disabled or produced no loss).

        Args:
            batch_replay_items: replay items exposing ``get_action()``,
                ``get_reward()``, ``get_log_prob()`` and
                ``get_volatile_features()`` (whose dict must contain
                ``"goal_sample_prob"``).

        Returns:
            The total loss tensor.
        """
        action_batch = []
        immediate_rewards = []
        log_probabilities = []
        chosen_goal_prob = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())
            immediate_rewards.append(replay_item.get_reward())
            log_probabilities.append(replay_item.get_log_prob())
            chosen_goal_prob.append(
                replay_item.get_volatile_features()["goal_sample_prob"])

        log_probabilities = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))
        immediate_rewards = cuda_var(
            torch.from_numpy(np.array(immediate_rewards)).float())

        model_log_prob_batch = log_probabilities
        # model_log_prob_batch = self.model.get_probs_batch(agent_observation_state_ls)
        chosen_log_action_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))

        # Take the probability of goal generation into account.
        # "goal_sample_prob" is a plain probability (it is the value the goal
        # was sampled with), so it has to be moved to log space before being
        # added to the action log-probabilities. The clamp keeps the log
        # finite should a zero-probability entry ever be selected.
        chosen_log_goal_prob = torch.log(
            torch.cat(chosen_goal_prob).clamp(min=1e-12))

        chosen_log_probs = chosen_log_action_probs.view(
            -1) + chosen_log_goal_prob.view(-1)
        reward_log_probs = immediate_rewards * chosen_log_probs

        # Hard-coded reference action distribution with four entries
        # (assumes a 4-action space -- TODO confirm).
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(model_log_prob_batch)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        # cross_entropy is only tracked for logging; it is not part of the loss.
        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        # self.entropy = -torch.mean(torch.sum(model_log_prob_batch * model_prob_batch, 1))
        self.entropy = -torch.sum(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(reward_log_probs)  # / num_states
        # Essentially we want the objective to increase and cross entropy to decrease
        # The entropy bonus decays by 0.01 per epoch and is floored at zero.
        entropy_coef = max(0, self.entropy_coef - self.epoch * 0.01)
        loss = -objective - entropy_coef * self.entropy
        # NOTE(review): once entropy_coef reaches 0 the denominator is zero and
        # the ratio becomes inf/nan; it is a diagnostic only.
        self.ratio = torch.abs(objective) / (entropy_coef * self.entropy
                                             )  # we want the ratio to be high

        # loss = -objective + self.entropy_coef * self.cross_entropy

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            # Same None guard as the other auxiliary objectives, so a missing
            # loss is skipped instead of raising on `coeff * None`.
            if self.symbolic_language_prediction_loss is not None:
                self.symbolic_language_prediction_loss = \
                    self.constants["symbolic_language_prediction_coeff"] * \
                    self.symbolic_language_prediction_loss
                loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            self.goal_prediction_loss, _, _ = self.goal_prediction_calculator.calc_loss(
                batch_replay_items)
            if self.goal_prediction_loss is not None:
                self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * \
                                            self.goal_prediction_loss
                loss = loss + self.goal_prediction_loss  # * len(batch_replay_items)  # scale the loss
        else:
            self.goal_prediction_loss = None

        return loss

    def _sample_goal(self, exploration_image, data_point, panaroma=True):
        """Sample a goal pixel from the predictor and map it to a world position.

        The predictor's attention distribution over the exploration image is
        sampled (its last entry is excluded from sampling), the flat index is
        turned into a row/column on a 192-wide grid, and that pixel is
        projected to an (x, z) position with ``get_inverse_object_position``.

        Args:
            exploration_image: image passed to the predictor as start image.
            data_point: data point providing the instruction, start position
                and destination list.
            panaroma: when true, the 192 columns are treated as six 32-wide
                regions and the viewing angle is derived from the region the
                sampled column falls in; otherwise the start angle is used.

        Returns:
            Tuple ``(predicted_goal_pos, dist, screen_pos, sampled_prob)``:
            the projected (x, z) goal, its Euclidean distance to the last
            destination of the data point, the sampled (row, col) on the full
            grid, and the attention entry at the sampled index.
        """
        initial_state = AgentObservedState(
            instruction=data_point.instruction,
            config=self.config,
            constants=self.constants,
            start_image=exploration_image,
            previous_action=None,
            pose=None,
            position_orientation=data_point.get_start_pos(),
            data_point=data_point)

        volatile = self.local_predictor_model.get_attention_prob(
            initial_state, model_state=None)
        attention = volatile["attention_probs"]

        # Sample among all entries except the last one.
        candidate_probs = list(attention.view(-1)[:-1].data.cpu().numpy())
        sampled_ix = gp.sample_action_from_prob(candidate_probs)
        sampled_prob = attention[sampled_ix]

        # Known issue carried over from the original code: the calculations
        # below are reported to be buggy when the inferred index is the last
        # value.

        # Flat index -> (row, col) on the 192-column grid.
        predicted_row = int(sampled_ix / 192.0)
        predicted_col = sampled_ix % 192
        screen_pos = (predicted_row, predicted_col)

        pos = data_point.get_start_pos()
        if panaroma:
            # Which of the six 32-wide images holds the goal, and the column
            # inside that image.
            region_index = int(predicted_col / 32)
            predicted_col = predicted_col % 32
            y_angle = GoalPredictionSingle360ImageSupervisedLearningFromDisk.\
                get_new_pos_angle_from_region_index(region_index, pos)
        else:
            y_angle = pos[2]
        metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": y_angle}

        # Aim at the pixel centre.
        row = predicted_row + 0.5
        col = predicted_col + 0.5

        start_pos = current_pos_from_metadata(metadata)
        start_pose = current_pose_from_metadata(metadata)

        x_goal, z_goal = data_point.get_destination_list()[-1]
        height_drone = 2.5
        predicted_goal_pos = get_inverse_object_position(
            row, col, height_drone, 30, 32, 32,
            (start_pos[0], start_pos[1], start_pose))
        x_gen, z_gen = predicted_goal_pos
        predicted_goal_pos = (x_gen, z_gen)

        x_diff = x_gen - x_goal
        z_diff = z_gen - z_goal
        dist = math.sqrt(x_diff * x_diff + z_diff * z_diff)

        return predicted_goal_pos, dist, screen_pos, sampled_prob

    @staticmethod
    def do_train(shared_navigator_model,
                 shared_predictor_model,
                 config,
                 action_space,
                 meta_data_util,
                 constants,
                 train_dataset,
                 tune_dataset,
                 experiment,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 navigator_model_type,
                 predictor_model_type,
                 use_pushover=False):
        """Entry point that runs ``do_train_`` and reports any failure.

        All arguments are forwarded unchanged to ``do_train_``. Any
        ``Exception`` raised there is printed as a traceback and then
        swallowed, so this function itself returns normally.
        """
        try:
            AsynchronousTwoStageContextualBandit.do_train_(
                shared_navigator_model=shared_navigator_model,
                shared_predictor_model=shared_predictor_model,
                config=config,
                action_space=action_space,
                meta_data_util=meta_data_util,
                constants=constants,
                train_dataset=train_dataset,
                tune_dataset=tune_dataset,
                experiment=experiment,
                experiment_name=experiment_name,
                rank=rank,
                server=server,
                logger=logger,
                navigator_model_type=navigator_model_type,
                predictor_model_type=predictor_model_type,
                use_pushover=use_pushover)
        except Exception:
            # Print the active exception's traceback and carry on.
            traceback.print_exc()

    @staticmethod
    def do_train_(shared_navigator_model,
                  shared_predictor_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  navigator_model_type,
                  predictor_model_type,
                  use_pushover=False):
        """Run the two-stage contextual-bandit training loop for one client.

        For every epoch and every training data point, both local models are
        re-synced from their shared counterparts, a goal is sampled from the
        predictor on the agent's exploration image, and the navigator then
        samples actions until it picks the stop action or the action budget
        is used up. The collected replay items are handed to
        ``learner.do_update``. After each epoch both local models are saved,
        training statistics are logged and, when ``tune_dataset`` is
        non-empty, ``agent.test`` is run on it.

        Args:
            shared_navigator_model: navigator model that the local navigator
                is synced from before each data point.
            shared_predictor_model: predictor model that the local predictor
                is synced from before each data point.
            config: configuration mapping; read here for ``"port"``.
            action_space: provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: forwarded to the agent and the learner.
            constants: constants mapping; read here for ``"max_epochs"``,
                ``"horizon"`` and ``"max_extra_horizon"``.
            train_dataset: sized iterable of training data points.
            tune_dataset: sized collection evaluated after each epoch;
                skipped when empty.
            experiment: path prefix under which per-epoch models are saved.
            experiment_name: name given to ``Tensorboard`` and
                ``PushoverLogger``.
            rank: client index; only rank 0 creates a ``Tensorboard`` and the
                rank is embedded in the saved model file names.
            server: simulator server interface used for all rollouts.
            logger: object exposing ``log(message)``.
            navigator_model_type: callable invoked as
                ``navigator_model_type(config, constants)``.
            predictor_model_type: callable building the local predictor; it
                is called with a fixed ``final_model_type`` and
                ``final_dimension``.
            use_pushover: when true, a ``PushoverLogger`` is created and
                passed to the evaluation call.
        """

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        # The predictor's final dimension is 64 x 32 x (32 * 6): six 32-wide
        # views side by side.
        local_predictor_model = predictor_model_type(
            config,
            constants,
            final_model_type="unet-positional-encoding",
            final_dimension=(64, 32, 32 * 6))
        local_navigator_model = navigator_model_type(config, constants)
        # local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = PredictorPlannerAgent(server=server,
                                      predictor_model=local_predictor_model,
                                      model=local_navigator_model,
                                      test_policy=test_policy,
                                      action_space=action_space,
                                      meta_data_util=meta_data_util,
                                      config=config,
                                      constants=constants)
        logger.log("Created Agent...")

        # Running count of every sampled action index, including sampled stops.
        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = AsynchronousTwoStageContextualBandit(
            shared_navigator_model, local_navigator_model,
            shared_predictor_model, local_predictor_model, action_space,
            meta_data_util, config, constants, tensorboard)

        # Launch unity
        launch_k_unity_builds([config["port"]],
                              "./simulators/NavDroneLinuxBuild.x86_64")

        for epoch in range(1, max_epochs + 1):

            learner.epoch = epoch
            # Per-epoch training statistics; the first two are turned into a
            # percentage and a mean after the data loop.
            task_completion_accuracy = 0
            mean_stop_dist_error = 0
            stop_dist_errors = []
            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_navigator_model.load_from_state_dict(
                    shared_navigator_model.get_state_dict())
                local_predictor_model.load_from_state_dict(
                    shared_predictor_model.get_state_dict())

                # NOTE(review): the condition fires on every 100th item but the
                # message prints the zero-based index (99, 199, ...).
                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                num_actions = 0
                max_num_actions = constants["horizon"] + constants[
                    "max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                # Generate goal probability
                # Test image
                panorama = agent.get_exploration_image()

                # Sample a goal location and compute 3D mapping
                # The goal is sampled once per data point and reused for every
                # step of the rollout below.
                predicted_goal, predictor_error, predicted_pixel, sample_prob = learner._sample_goal(
                    panorama, data_point, panaroma=True)

                # Pose is the heading angle bucketed into 15-unit bins
                # (presumably degrees -- TODO confirm).
                pose = int(metadata["y_angle"] / 15.0)
                position_orientation = (metadata["x_pos"], metadata["z_pos"],
                                        metadata["y_angle"])
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    pose=pose,
                    position_orientation=position_orientation,
                    data_point=data_point)
                current_bot_location = metadata["x_pos"], metadata["z_pos"]
                current_bot_pose = metadata["y_angle"]
                # Attach the predicted goal, re-expressed for the agent's
                # current location and pose, to the state.
                state.goal = PredictorPlannerAgent.get_goal_location(
                    current_bot_location, current_bot_pose, predicted_goal, 32,
                    32)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                # Stays True when the loop ends because the action budget ran
                # out rather than because the policy sampled the stop action.
                forced_stop = True

                while num_actions < max_num_actions:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, volatile = \
                        local_navigator_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
                    # The goal's sampling probability rides along in `volatile`
                    # so the loss can read it back.
                    volatile["goal_sample_prob"] = sample_prob
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   volatile=volatile)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    pose = int(metadata["y_angle"] / 15.0)
                    position_orientation = (metadata["x_pos"],
                                            metadata["z_pos"],
                                            metadata["y_angle"])
                    state = state.update(
                        image,
                        action,
                        pose=pose,
                        position_orientation=position_orientation,
                        data_point=data_point)

                    current_bot_location = metadata["x_pos"], metadata["z_pos"]
                    current_bot_pose = metadata["y_angle"]
                    state.goal = PredictorPlannerAgent.get_goal_location(
                        current_bot_location, current_bot_pose, predicted_goal,
                        32, 32)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                # The halt is sent whether or not the policy chose to stop.
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                total_reward += reward

                # An episode counts as completed when it stops within 5.0 of
                # the goal (units as reported by the server).
                if metadata["stop_dist_error"] < 5.0:
                    task_completion_accuracy += 1
                mean_stop_dist_error += metadata["stop_dist_error"]
                stop_dist_errors.append(metadata["stop_dist_error"])

                if tensorboard is not None:
                    tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

                # Store it in the replay memory list
                # Only a stop the policy chose itself becomes a replay item; a
                # forced stop is not attributed to the policy.
                if not forced_stop:
                    volatile["goal_sample_prob"] = sample_prob
                    replay_item = ReplayMemoryItem(
                        state,
                        action_space.get_stop_action_index(),
                        reward,
                        log_prob=log_probabilities,
                        volatile=volatile)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                # The list can be empty here only when the action budget is
                # zero, since a sampled stop always adds one item.
                if len(batch_replay_items) > 0:  # 32:
                    loss_val = learner.do_update(batch_replay_items)
                    # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                    # del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        # The `.data[0]` reads below assume each statistic is a
                        # one-element tensor (pre-0.4 PyTorch style) -- TODO
                        # confirm against the installed torch version.
                        # NOTE(review): the tag says "gold" but the value is
                        # the probability of the *sampled* goal.
                        tensorboard.log_scalar("gold_sample_prob",
                                               float(sample_prob.data[0]))
                        tensorboard.log_scalar("predicted_error",
                                               predictor_error)
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        # NOTE(review): divides by num_actions + 1 even when
                        # the stop was forced and only num_actions items exist.
                        entropy = float(
                            learner.entropy.data[0]) / float(num_actions + 1)
                        tensorboard.log_scalar("entropy", entropy)
                        tensorboard.log_scalar("total_reward", total_reward)

                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            # Goes through learner.tensorboard, which was
                            # constructed from the same `tensorboard` object.
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)

            # Save the model
            # The *local* models are saved; they were last synced from the
            # shared models before the final data point of the epoch.
            local_navigator_model.save_model(experiment +
                                             "/navigator_contextual_bandit_" +
                                             str(rank) + "_epoch_" +
                                             str(epoch))
            local_predictor_model.save_model(experiment +
                                             "/predictor_contextual_bandit_" +
                                             str(rank) + "_epoch_" +
                                             str(epoch))
            logger.log("Training data action counts %r" % action_counts)
            # NOTE(review): both divisions raise ZeroDivisionError when
            # train_dataset is empty.
            mean_stop_dist_error = mean_stop_dist_error / float(
                len(train_dataset))
            task_completion_accuracy = (task_completion_accuracy *
                                        100.0) / float(len(train_dataset))
            logger.log("Training: Mean stop distance error %r" %
                       mean_stop_dist_error)
            logger.log("Training: Task completion accuracy %r " %
                       task_completion_accuracy)
            # Histogram of stop-distance errors with bin edges 0, 3, ..., 78.
            bins = range(0, 80, 3)  # range of distance
            histogram, _ = np.histogram(stop_dist_errors, bins)
            logger.log("Histogram of train errors %r " % histogram)

            if tune_dataset_size > 0:
                # Test on tuning data
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)
# Example #12
# 0
class TmpStreetViewAsynchronousContextualBandit(AbstractLearning):
    """ Temp file with modification for streetview corpus.
    Perform Contextual Bandit learning (Kakade and Langford (circa 2006) & Misra, Langford and Artzi EMNLP 2017) """
    def __init__(self, shared_model, local_model, action_space, meta_data_util,
                 config, constants, tensorboard):
        """Set up the learner, its auxiliary loss calculators and the optimizer.

        Args:
            shared_model: model whose ``get_parameters()`` are optimised by
                Adam and which is handed to ``AbstractLearning``.
            local_model: model passed to every auxiliary loss calculator.
            action_space: stored on the instance for the training routines.
            meta_data_util: stored on the instance for the training routines.
            config: mapping holding the ``do_*`` auxiliary-objective flags.
            constants: mapping holding ``max_epochs``, ``entropy_coefficient``
                and ``learning_rate``.
            tensorboard: logging handle forwarded to ``AbstractLearning``.
        """
        self.max_epoch = constants["max_epochs"]
        self.shared_model = shared_model
        self.local_model = local_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.ratio = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        # Declare every auxiliary loss up front so the attributes can be read
        # (e.g. by the logging code in do_train_) even before the first
        # calc_loss call or when the corresponding objective is disabled.
        self.action_prediction_loss = None
        self.temporal_autoencoder_loss = None
        self.object_detection_loss = None
        self.symbolic_language_prediction_loss = None
        self.goal_prediction_loss = None

        # The object-detection and goal-prediction calculators, as well as
        # do_train_, read self.image_height / self.image_width. Resolve them
        # only when an enabled objective needs them, so models that do not
        # need these objectives are not required to expose an image module.
        # NOTE(review): assumes shared_model.image_module.get_final_dimension()
        # returns (channels, height, width) -- confirm for the streetview model.
        self.image_channels = None
        self.image_height = None
        self.image_width = None
        if self.config["do_object_detection"] or self.config["do_goal_prediction"]:
            self.image_channels, self.image_height, self.image_width = \
                shared_model.image_module.get_final_dimension()

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_model)
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_model)
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.local_model,
                num_objects=67,
                camera_angle=60,
                image_height=self.image_height,
                image_width=self.image_width,
                object_height=0)  # -2.5)
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_model)
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.local_model, self.image_height, self.image_width)

        self.optimizer = optim.Adam(shared_model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_model, self.local_model,
                                  self.calc_loss, self.optimizer, self.config,
                                  self.constants, self.tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the contextual-bandit loss for one batch of replay items.

        The main term is ``-sum(reward * log_prob(chosen action))`` minus an
        entropy bonus whose coefficient is annealed by ``0.01`` per epoch
        (``self.epoch``) and floored at zero. Every enabled auxiliary
        objective adds its own coefficient-weighted loss on top.

        Side effects: updates ``self.cross_entropy``, ``self.entropy``,
        ``self.ratio`` and all ``self.*_loss`` auxiliary attributes (set to
        ``None`` when the objective is disabled).

        Args:
            batch_replay_items: replay items exposing ``get_action()``,
                ``get_reward()`` and ``get_log_prob()``.

        Returns:
            The total loss tensor.
        """
        action_batch = []
        immediate_rewards = []
        log_probabilities = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())
            immediate_rewards.append(replay_item.get_reward())
            log_probabilities.append(replay_item.get_log_prob())

        # The log-probabilities were produced during the rollout; they are
        # concatenated rather than recomputed from the observed states.
        model_log_prob_batch = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))
        immediate_rewards = cuda_var(
            torch.from_numpy(np.array(immediate_rewards)).float())

        chosen_log_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))
        reward_log_probs = immediate_rewards * chosen_log_probs.view(-1)

        # NOTE(review): the hard-coded gold distribution has four entries, so
        # this presumably assumes a four-action space -- confirm.
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(model_log_prob_batch)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        # Cross entropy is tracked for logging only; it is not part of the loss.
        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        # Entropy is summed (not averaged) over the batch.
        self.entropy = -torch.sum(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(reward_log_probs)
        # Essentially we want the objective to increase and cross entropy to decrease
        entropy_coef = max(0, self.entropy_coef - self.epoch * 0.01)
        loss = -objective - entropy_coef * self.entropy
        # We want the ratio to be high.
        # NOTE(review): the denominator is zero once the annealed coefficient
        # reaches 0, so the logged ratio is no longer finite from then on.
        self.ratio = torch.abs(objective) / (entropy_coef * self.entropy)

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            # Same None guard as the sibling objectives: a calculator that has
            # nothing to score must not break the main loss.
            if self.symbolic_language_prediction_loss is not None:
                self.symbolic_language_prediction_loss = \
                    self.constants["symbolic_language_prediction_coeff"] * \
                    self.symbolic_language_prediction_loss
                loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            # The goal calculator returns a triple; only the loss is used here.
            self.goal_prediction_loss, _, _ = self.goal_prediction_calculator.calc_loss(
                batch_replay_items)
            if self.goal_prediction_loss is not None:
                self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * \
                                            self.goal_prediction_loss
                loss = loss + self.goal_prediction_loss
        else:
            self.goal_prediction_loss = None

        return loss

    @staticmethod
    def do_train(shared_model,
                 config,
                 action_space,
                 meta_data_util,
                 constants,
                 train_dataset,
                 tune_dataset,
                 experiment,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 model_type,
                 use_pushover=False):
        """Entry point for a training client: run ``do_train_`` and report failures.

        All arguments are forwarded unchanged to ``do_train_``. Any
        ``Exception`` raised by the training loop is printed as a traceback
        and not re-raised, so this call itself returns normally.
        """
        learner_cls = TmpStreetViewAsynchronousContextualBandit
        try:
            learner_cls.do_train_(shared_model, config, action_space,
                                  meta_data_util, constants, train_dataset,
                                  tune_dataset, experiment, experiment_name,
                                  rank, server, logger, model_type,
                                  use_pushover)
        except Exception:
            # Print the active exception's traceback instead of propagating it.
            traceback.print_exc()

    @staticmethod
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):
        """Run the contextual-bandit training loop for one client.

        For every epoch each training example is rolled out with a local
        model (re-synchronised from ``shared_model`` before every example),
        the rollout is converted into replay items and handed to
        ``learner.do_update``. After each epoch the local model is saved,
        summary statistics are logged and, when ``tune_dataset`` is
        non-empty, ``agent.test`` is run on it.

        Args:
            shared_model: model providing ``get_state_dict()``; also passed
                to the learner.
            config: configuration mapping (``do_goal_prediction`` is read here).
            action_space: provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: forwarded to the agent and the learner.
            constants: mapping with ``max_epochs``, ``horizon`` and
                ``max_extra_horizon``.
            train_dataset: iterable of data points with an ``instruction``.
            tune_dataset: data points evaluated after every epoch.
            experiment: path prefix under which the model is saved.
            experiment_name: name used for the tensorboard / pushover loggers.
            rank: client index; only rank 0 creates a tensorboard handle.
            server: environment server used for the rollouts.
            logger: object exposing ``log(str)``.
            model_type: callable ``model_type(config, constants)`` creating
                the local model.
            use_pushover: when true, a ``PushoverLogger`` is created and
                passed to ``agent.test``.
        """

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)
        # local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        # Per-action statistics; created once, so they accumulate over all epochs.
        action_counts = [0] * action_space.num_actions()
        action_rewards = [0.0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = TmpStreetViewAsynchronousContextualBandit(
            shared_model, local_model, action_space, meta_data_util, config,
            constants, tensorboard)

        for epoch in range(1, max_epochs + 1):

            # The learner uses the epoch to anneal its entropy coefficient.
            learner.epoch = epoch
            task_completion_accuracy = 0
            mean_stop_dist_error = 0

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_model.load_from_state_dict(shared_model.get_state_dict())

                # Progress report every 100 examples (logs the zero-based index).
                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)
                    # max(1.0, count) avoids dividing by zero for unused actions.
                    mean_action_reward = [
                        action_sum / max(1.0, action_count) for (
                            action_sum,
                            action_count) in zip(action_rewards, action_counts)
                    ]
                    logger.log("Training data action rewards %r" %
                               mean_action_reward)

                num_actions = 0
                max_num_actions = constants["horizon"] + constants[
                    "max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                state = AgentObservedState(instruction=data_point.instruction,
                                           config=config,
                                           constants=constants,
                                           start_image=image,
                                           previous_action=None,
                                           data_point=data_point)
                # state.goal = GoalPrediction.get_goal_location(metadata, data_point,
                #                                               learner.image_height, learner.image_width)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                # Stays True unless the policy itself samples the stop action.
                forced_stop = True

                while num_actions < max_num_actions:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, volatile = \
                        local_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    # Generate goal
                    if config["do_goal_prediction"]:
                        goal = learner.goal_prediction_calculator.get_goal_location(
                            metadata, data_point, learner.image_height,
                            learner.image_width)
                    else:
                        goal = None

                    # A sampled stop ends the rollout here; the stop itself is
                    # sent by the halt call after the loop.
                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)
                    action_rewards[action] += reward

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   volatile=volatile,
                                                   goal=goal)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    state = state.update(image, action, data_point=data_point)
                    # state.goal = GoalPrediction.get_goal_location(metadata, data_point,
                    #                                               learner.image_height, learner.image_width)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                total_reward += reward

                # An example counts as completed when it stops within 5.0 of the goal.
                if metadata["navigation_error"] <= 5.0:
                    task_completion_accuracy += 1
                mean_stop_dist_error += metadata["navigation_error"]

                if tensorboard is not None:
                    tensorboard.log_scalar("navigation_error",
                                           metadata["navigation_error"])

                # Store it in the replay memory list
                # (only when the policy chose to stop; a stop forced by the
                # horizon limit is not added as a training item)
                if not forced_stop:
                    replay_item = ReplayMemoryItem(
                        state,
                        action_space.get_stop_action_index(),
                        reward,
                        log_prob=log_probabilities,
                        volatile=volatile,
                        goal=goal)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                if len(batch_replay_items) > 0:  # 32:
                    loss_val = learner.do_update(batch_replay_items)
                    # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                    # del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        # The learner sums entropy over the batch; divide by the
                        # rollout length (num_actions plus the final stop).
                        entropy = float(
                            learner.entropy.data[0]) / float(num_actions + 1)
                        tensorboard.log_scalar("entropy", entropy)
                        tensorboard.log_scalar("total_reward", total_reward)

                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        logger.log(
                            "Avg. Entropy %r, Total Reward %r, Rollout Length %r, stop-error %r, ratio %r "
                            % (entropy, total_reward, num_actions + 1,
                               metadata["navigation_error"], ratio))

                        # Auxiliary losses are None when disabled or not computed.
                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)

            # Save the model
            local_model.save_model(experiment + "/contextual_bandit_" +
                                   str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)
            mean_action_reward = [
                action_sum / max(1.0, action_count)
                for (action_sum,
                     action_count) in zip(action_rewards, action_counts)
            ]
            logger.log("Training data action rewards %r" % mean_action_reward)
            # NOTE(review): both divisions below raise ZeroDivisionError when
            # train_dataset is empty.
            mean_stop_dist_error = mean_stop_dist_error / float(
                len(train_dataset))
            task_completion_accuracy = (task_completion_accuracy *
                                        100.0) / float(len(train_dataset))
            logger.log("Training: Mean stop distance error %r" %
                       mean_stop_dist_error)
            logger.log("Training: Task completion accuracy %r " %
                       task_completion_accuracy)

            if tune_dataset_size > 0:
                logger.log("Evaluating on the tune split")
                # Test on tuning data
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)

                # Test on a selected train set
                # logger.log("Evaluating on first 50 examples in the train split.")
                # agent.test(train_dataset[0:50], tensorboard=tensorboard,
                #            logger=logger, pushover_logger=pushover_logger)

    @staticmethod
    def do_test(shared_model,
                config,
                action_space,
                meta_data_util,
                constants,
                test_dataset,
                experiment_name,
                rank,
                server,
                logger,
                model_type,
                use_pushover=False):
        """Entry point for a test client: run ``do_test_`` and report failures.

        All arguments are forwarded unchanged to ``do_test_``. Any
        ``Exception`` raised during evaluation is printed as a traceback and
        not re-raised, so this call itself returns normally.
        """
        learner_cls = TmpStreetViewAsynchronousContextualBandit
        try:
            learner_cls.do_test_(shared_model, config, action_space,
                                 meta_data_util, constants, test_dataset,
                                 experiment_name, rank, server, logger,
                                 model_type, use_pushover)
        except Exception:
            # Print the active exception's traceback instead of propagating it.
            traceback.print_exc()

    @staticmethod
    def do_test_(shared_model,
                 config,
                 action_space,
                 meta_data_util,
                 constants,
                 test_dataset,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 model_type,
                 use_pushover=False):
        """Evaluate the shared model's current weights on ``test_dataset``.

        A fresh local model is built with ``model_type(config, constants)``,
        loaded from ``shared_model.get_state_dict()`` and evaluated through
        ``Agent.test`` using the argmax policy. Nothing is evaluated when the
        dataset is empty.
        """
        server.initialize_server()

        # Evaluation always acts greedily.
        greedy_policy = gp.get_argmax_action

        # Only client 0 owns a tensorboard handle.
        tensorboard = Tensorboard(experiment_name) if rank == 0 else None
        pushover_logger = PushoverLogger(
            experiment_name) if use_pushover else None

        # Local copy of the model used for the rollouts.
        rollout_model = model_type(config, constants)

        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=rollout_model,
                      test_policy=greedy_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        num_test_examples = len(test_dataset)

        # Pull in the weights of the shared model before evaluating.
        rollout_model.load_from_state_dict(shared_model.get_state_dict())

        if num_test_examples <= 0:
            return

        agent.test(test_dataset,
                   tensorboard=tensorboard,
                   logger=logger,
                   pushover_logger=pushover_logger)
예제 #13
0
class SupervisedLearningFromDisk(AbstractLearning):
    """ Perform maximum likelihood on oracle trajectories using images stored on disk
    and hence does not need client or server. """
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Set up the supervised learner, its auxiliary calculators and optimizer.

        Args:
            model: model being trained; its ``get_parameters()`` are
                optimised by Adam and it is passed to every auxiliary loss
                calculator.
            action_space: stored on the instance (its stop action index is
                used by ``calc_log_prob``).
            meta_data_util: stored on the instance.
            config: mapping holding the ``do_*`` auxiliary-objective flags.
            constants: mapping holding ``max_epochs``, ``entropy_coefficient``
                and ``learning_rate``.
            tensorboard: logging handle forwarded to ``AbstractLearning``.
        """
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]

        # Declare every value that calc_loss publishes up front, so reading
        # them before the first update (or with an objective disabled) gives
        # None instead of raising AttributeError.
        self.mean_factor_entropy = None
        self.action_prediction_loss = None
        self.temporal_autoencoder_loss = None
        self.object_detection_loss = None
        self.symbolic_language_prediction_loss = None

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.model)
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.model)
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(self.model)
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.model)

        self.optimizer = optim.Adam(model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the maximum-likelihood loss for one batch of replay items.

        The main term is the negative mean log-probability of the actions
        stored in the replay items. For an
        ``IncrementalModelRecurrentImplicitFactorizationResnet`` model the
        mean factor entropy is added, and every enabled auxiliary objective
        adds its own coefficient-weighted loss.

        Side effects: updates ``self.cross_entropy``, ``self.entropy``,
        ``self.mean_factor_entropy`` and all ``self.*_loss`` auxiliary
        attributes (set to ``None`` when the objective is disabled).

        Args:
            batch_replay_items: replay items exposing ``get_action()``,
                ``get_log_prob()`` and ``get_factor_entropy()``.

        Returns:
            The total loss tensor.
        """
        action_batch = []
        log_probabilities = []
        factor_entropy = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())

        # The log-probabilities were produced during the rollout; they are
        # concatenated rather than recomputed from the observed states.
        model_log_prob_batch = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))

        num_states = int(action_batch.size()[0])
        chosen_log_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))

        # NOTE(review): the hard-coded gold distribution has four entries, so
        # this presumably assumes a four-action space -- confirm.
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(model_log_prob_batch)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        # Cross entropy and entropy are tracked for logging only; neither is
        # part of the loss below.
        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        self.entropy = -torch.mean(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(chosen_log_probs) / num_states
        # Essentially we want the objective to increase and cross entropy to decrease
        loss = -objective

        # Minimize the Factor Entropy if the model is implicit factorization model
        if isinstance(self.model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(factor_entropy))
            loss = loss + self.mean_factor_entropy
        else:
            self.mean_factor_entropy = None

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            # Same None guard as the objectives above: a calculator that has
            # nothing to score must not break the main loss.
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            if self.symbolic_language_prediction_loss is not None:
                self.symbolic_language_prediction_loss = \
                    self.constants["symbolic_language_prediction_coeff"] * \
                    self.symbolic_language_prediction_loss
                loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        return loss

    @staticmethod
    def parse(folder_name):
        """Load the stored image sequence of every example under ``folder_name``.

        The folder is expected to contain sub-folders ``example_0``,
        ``example_1``, ... (one per directory entry), each holding
        ``image_0.png``, ``image_1.png``, ... Every image is read and its
        axes are permuted from ``(0, 1, 2)`` to ``(2, 0, 1)``.

        Args:
            folder_name: path of the folder holding the example sub-folders.

        Returns:
            A list with one entry per example; each entry is the list of
            that example's images in index order.
        """
        start = time.time()
        dataset = []
        # Every directory entry is counted as one example folder.
        num_examples = len(os.listdir(folder_name))
        for i in range(num_examples):
            example_folder_name = folder_name + "/example_" + str(i)
            # The number of .png files determines how many frames are read.
            num_images = sum(
                1 for entry in os.listdir(example_folder_name)
                if entry.endswith('.png'))
            # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this
            # call requires an older SciPy release.
            # The two swaps move the last axis to the front (presumably
            # height x width x channels -> channels x height x width).
            images = [
                scipy.misc.imread(example_folder_name + "/image_" + str(j) +
                                  ".png").swapaxes(1, 2).swapaxes(0, 1)
                for j in range(num_images)
            ]
            dataset.append(images)
        elapsed = time.time() - start
        # The original "% seconds" was parsed as "% s" + "econds"; use an
        # explicit conversion so the message reads "... in time <t> seconds".
        logging.info("Parsed dataset of size %r in time %r seconds",
                     len(dataset), elapsed)
        return dataset

    def calc_log_prob(self, tune_dataset, tune_image, tensorboard):
        """Log the model's mean per-step log-probability of the stored trajectories.

        For each data point the model is stepped along
        ``data_point.get_trajectory()`` using the stored images, followed by
        one extra step scored against the stop action. The per-example mean
        is sent to ``tensorboard`` as ``Validation_Log_Prob`` and the mean
        over all examples is written to the log. Nothing is returned.

        Args:
            tune_dataset: data points to score.
            tune_image: per-example image sequences, indexed like
                ``tune_dataset``; entry ``k + 1`` is the image seen after
                the ``k``-th action.
            tensorboard: object exposing ``log_scalar(name, value)``.
        """
        summed_example_means = 0
        for example_ix, data_point in enumerate(tune_dataset):
            frames = tune_image[example_ix]

            recurrent_state = None
            state = AgentObservedState(instruction=data_point.instruction,
                                       config=self.config,
                                       constants=self.constants,
                                       start_image=frames[0],
                                       previous_action=None,
                                       pose=None,
                                       position_orientation=None,
                                       data_point=data_point)
            trajectory = data_point.get_trajectory()

            example_log_prob = 0

            # Score each recorded action, then advance to the frame that follows it.
            for next_frame_ix, action in enumerate(trajectory, 1):
                step_log_probs, recurrent_state, _ = self.model.get_probs(
                    state, recurrent_state)
                example_log_prob += float(step_log_probs.data[0][action])
                state = state.update(frames[next_frame_ix],
                                     action,
                                     pose=None,
                                     position_orientation=None,
                                     data_point=data_point)

            # The trajectory is implicitly terminated by a stop action.
            step_log_probs, recurrent_state, _ = self.model.get_probs(
                state, recurrent_state)
            stop_ix = self.action_space.get_stop_action_index()
            example_log_prob += float(step_log_probs.data[0][stop_ix])

            num_steps = len(trajectory) + 1
            example_mean = example_log_prob / float(num_steps)
            tensorboard.log_scalar("Validation_Log_Prob", example_mean)
            summed_example_means += example_mean

        # max(..., 1) keeps an empty dataset from dividing by zero.
        num_examples = max(len(tune_dataset), 1)
        summed_example_means /= float(num_examples)
        logging.info("Mean Validation Log Prob is %r", summed_example_means)

    def do_train(self, train_dataset, train_images, tune_dataset, tune_images,
                 experiment_name):
        """Train the model on demonstrations for ``self.max_epoch`` epochs.

        Each epoch first evaluates the tuning data through
        ``self.calc_log_prob``, then walks ``train_dataset`` in order. Every
        demonstration is converted into replay items, ``self.do_update`` is
        called once per episode, the resulting statistics are written to
        ``self.tensorboard``, and the model is saved when the epoch ends.

        Args:
            train_dataset: Sized iterable of training data points. Each data
                point must expose ``instruction`` and ``get_trajectory()``.
            train_images: Indexable by data point position. Entry ``i`` holds
                the images of data point ``i``; index 0 is the start image and
                index ``k + 1`` is the image used after the k-th action.
            tune_dataset: Tuning data points, forwarded to
                ``self.calc_log_prob``.
            tune_images: Images for ``tune_dataset``, forwarded likewise.
            experiment_name: Path prefix. The model is saved to
                ``experiment_name + "/contextual_bandit_resnet_epoch_<epoch>"``.
        """

        dataset_size = len(train_dataset)
        # Number of episodes accumulated before each call to do_update.
        episodes_per_update = 1

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)

            # Test on tuning data
            self.calc_log_prob(tune_dataset,
                               tune_images,
                               tensorboard=self.tensorboard)

            batch_replay_items = []
            episodes_in_batch = 0

            for data_point_ix, data_point in enumerate(train_dataset):

                # Logged before the current data point is processed, so the
                # zero-based index equals the number of points already done.
                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix,
                                 dataset_size)

                batch_replay_items.extend(
                    self._demonstration_to_replay_items(
                        data_point, train_images[data_point_ix]))

                # Perform update
                episodes_in_batch += 1
                if episodes_in_batch == episodes_per_update:
                    episodes_in_batch = 0
                    loss_val = self.do_update(batch_replay_items)
                    del batch_replay_items[:]  # in place list clear
                    self._log_update_statistics(loss_val)

            # Save the model
            self.model.save_model(experiment_name +
                                  "/contextual_bandit_resnet_epoch_" +
                                  str(epoch))

    def _demonstration_to_replay_items(self, data_point, images):
        """Turn one demonstration into a list of replay items.

        The model is queried at every step, but the agent state is always
        advanced with the demonstrated action from the trajectory, never with
        an action chosen from the model output. One item is produced per
        demonstrated action, plus a final item for the stop action.

        Args:
            data_point: Data point exposing ``instruction`` and
                ``get_trajectory()``.
            images: Images of this data point; ``images[0]`` is the start
                image and ``images[k + 1]`` follows the k-th action.

        Returns:
            list: ``ReplayMemoryItem`` objects in trajectory order, the stop
            item last.
        """
        start_image = images[0]
        symbolic_form = nav_drone_symbolic_instructions.get_nav_drone_symbolic_instruction_segment(
            data_point)

        model_state = None
        state = AgentObservedState(instruction=data_point.instruction,
                                   config=self.config,
                                   constants=self.constants,
                                   start_image=start_image,
                                   previous_action=None,
                                   pose=None,
                                   position_orientation=None,
                                   data_point=data_point)

        replay_items = []
        for action_ix, action in enumerate(data_point.get_trajectory()):

            # Generate log-probabilities over actions for the current state
            log_probabilities, model_state, image_emb_seq = self.model.get_probs(
                state, model_state)

            # Image associated with the state reached after this action
            next_image = images[action_ix + 1]

            # Third positional argument is 0 (presumably the reward; confirm
            # against ReplayMemoryItem).
            replay_items.append(
                ReplayMemoryItem(state,
                                 action,
                                 0,
                                 log_prob=log_probabilities,
                                 symbolic_text=symbolic_form,
                                 image_emb_seq=image_emb_seq,
                                 text_emb=model_state[0]))

            # Update the agent state with the demonstrated action
            state = state.update(next_image,
                                 action,
                                 pose=None,
                                 position_orientation=None,
                                 data_point=data_point)

        # The demonstration ends with the stop action taken in the last state
        log_probabilities, model_state, image_emb_seq = self.model.get_probs(
            state, model_state)
        replay_items.append(
            ReplayMemoryItem(state,
                             self.action_space.get_stop_action_index(),
                             0,
                             log_prob=log_probabilities,
                             symbolic_text=symbolic_form,
                             image_emb_seq=image_emb_seq,
                             text_emb=model_state[0]))
        return replay_items

    def _log_update_statistics(self, loss_val):
        """Write the statistics of the latest update to ``self.tensorboard``.

        ``loss``, ``cross_entropy`` and ``entropy`` are always logged. Each
        auxiliary loss is logged only when its attribute exists and is not
        ``None``.

        Args:
            loss_val: Value returned by ``self.do_update``.
        """
        tensorboard = self.tensorboard

        # NOTE(review): ``.data[0]`` on a scalar loss presumably relies on an
        # older PyTorch where losses are 1-element tensors; confirm the
        # targeted version before upgrading.
        tensorboard.log_scalar("loss", loss_val)
        tensorboard.log_scalar("cross_entropy",
                               float(self.cross_entropy.data[0]))
        tensorboard.log_scalar("entropy", float(self.entropy.data[0]))

        # (attribute on self, reporter). Reporters are lambdas so that a
        # tensorboard method is only looked up when its loss is present.
        optional_losses = (
            ("action_prediction_loss",
             lambda value: tensorboard.log_action_prediction_loss(value)),
            ("temporal_autoencoder_loss",
             lambda value: tensorboard.log_temporal_autoencoder_loss(value)),
            ("object_detection_loss",
             lambda value: tensorboard.log_object_detection_loss(value)),
            ("symbolic_language_prediction_loss",
             lambda value: tensorboard.log_scalar(
                 "sym_language_prediction_loss", value)),
            ("mean_factor_entropy",
             lambda value: tensorboard.log_factor_entropy_loss(value)),
        )
        for attribute_name, report in optional_losses:
            # An auxiliary objective that was never configured may not have
            # created its attribute at all; treat that the same as None
            # instead of raising AttributeError after a completed update.
            loss = getattr(self, attribute_name, None)
            if loss is not None:
                report(float(loss.data[0]))