Пример #1
0
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Store the collaborators, build enabled auxiliary losses, create Adam.

        Args:
            model: object exposing ``get_parameters()``; it is also handed to
                every auxiliary loss calculator and to ``AbstractLearning``.
            action_space: stored unchanged on the instance.
            meta_data_util: stored unchanged on the instance.
            config: mapping read for ``num_client`` and the ``do_*`` flags.
            constants: mapping read for ``max_epochs``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: stored unchanged on the instance.
        """
        self.max_epoch = constants["max_epochs"]

        # Collaborators supplied by the caller.
        self.model, self.action_space = model, action_space
        self.meta_data_util = meta_data_util
        self.config, self.constants = config, constants
        self.num_client = config["num_client"]
        self.tensorboard = tensorboard

        # Entropy bookkeeping starts out unset.
        self.entropy = self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]

        # Auxiliary objectives: a calculator (and its loss slot) exists only
        # when the matching config flag is truthy.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(model)
            self.action_prediction_loss = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(model)
            self.temporal_autoencoder_loss = None
        if config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(model)
            self.object_detection_loss = None

        trainable_parameters = model.get_parameters()
        self.optimizer = optim.Adam(trainable_parameters,
                                    lr=constants["learning_rate"])

        # The base class receives the bound loss function and the optimizer.
        AbstractLearning.__init__(self, model, self.calc_loss, self.optimizer,
                                  config, constants)
    def __init__(self, shared_model, local_model, action_space, meta_data_util, config, constants, tensorboard):
        """Set up a learner that works with a shared model and a local model.

        Auxiliary loss calculators enabled in ``config`` are built on
        ``local_model``; the Adam optimizer is created over
        ``shared_model.get_parameters()``.

        Args:
            shared_model: model whose parameters the optimizer is built on.
            local_model: model handed to the auxiliary loss calculators.
            action_space: stored unchanged on the instance.
            meta_data_util: stored unchanged on the instance.
            config: mapping read for the ``do_*`` auxiliary-objective flags.
            constants: mapping read for ``max_epochs``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: stored on the instance and passed to the base class.
        """
        self.max_epoch = constants["max_epochs"]
        self.shared_model, self.local_model = shared_model, local_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config, self.constants = config, constants
        self.tensorboard = tensorboard
        self.entropy = self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]
        self.logger = None

        # Each auxiliary objective gets a calculator and an empty loss slot,
        # but only when its config flag is truthy.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(local_model)
            self.action_prediction_loss = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(local_model)
            self.temporal_autoencoder_loss = None
        if config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(local_model, num_objects=67)
            self.object_detection_loss = None
        if config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(local_model)
            self.symbolic_language_prediction_loss = None
        if config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(local_model)
            self.goal_prediction_loss = None

        shared_parameters = shared_model.get_parameters()
        self.optimizer = optim.Adam(shared_parameters, lr=constants["learning_rate"])

        AbstractLearning.__init__(self, shared_model, local_model, self.calc_loss,
                                  self.optimizer, config, constants, tensorboard)
Пример #3
0
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Store collaborators, set the beta schedule, and build the optimizer.

        ``beta`` and ``beta_exp_decay`` both start at 0.9 and are reported
        through ``logging.info``. Auxiliary loss calculators are created on
        ``model`` for every ``do_*`` flag that is truthy in ``config``.

        Args:
            model: object exposing ``get_parameters()``.
            action_space: stored unchanged on the instance.
            meta_data_util: stored unchanged on the instance.
            config: mapping read for ``num_client`` and the ``do_*`` flags.
            constants: mapping read for ``max_epochs``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: stored on the instance and passed to the base class.
        """
        self.max_epoch = constants["max_epochs"]
        self.model, self.action_space = model, action_space
        self.meta_data_util = meta_data_util
        self.config, self.constants = config, constants
        self.num_client = config["num_client"]
        self.tensorboard = tensorboard
        self.entropy = self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]

        # Starting beta and its exponential decay factor share one value.
        self.beta = self.beta_exp_decay = 0.9
        logging.info("DAGGER: using starting beta of %r and beta exp decay of %r",
                     self.beta, self.beta_exp_decay)

        # Auxiliary objectives, each gated by its config flag.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(model)
            self.action_prediction_loss = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(model)
            self.temporal_autoencoder_loss = None
        if config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(model, num_objects=67)
            self.object_detection_loss = None
        if config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(model)
            self.symbolic_language_prediction_loss = None

        trainable_parameters = model.get_parameters()
        self.optimizer = optim.Adam(trainable_parameters,
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, model, self.calc_loss, self.optimizer,
                                  config, constants, tensorboard)
class TmpAsynchronousContextualBandit(AbstractLearning):
    """ Perform Contextual Bandit learning (Kakade and Langford (circa 2006) & Misra, Langford and Artzi EMNLP 2017) """

    def __init__(self, shared_model, local_model, action_space, meta_data_util, config, constants, tensorboard):
        """Create the contextual-bandit learner.

        The optimizer is an Adam instance over
        ``shared_model.get_parameters()``; the auxiliary loss calculators
        enabled in ``config`` are built on ``local_model``. ``self.logger``
        starts as ``None`` and is expected to be assigned by the caller.

        Args:
            shared_model: model whose parameters are optimized.
            local_model: model handed to the auxiliary loss calculators.
            action_space: stored unchanged on the instance.
            meta_data_util: stored unchanged on the instance.
            config: mapping read for the ``do_*`` auxiliary-objective flags.
            constants: mapping read for ``max_epochs``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: stored on the instance and passed to the base class.
        """
        self.max_epoch = constants["max_epochs"]

        # Plain references to what the caller handed in.
        self.shared_model, self.local_model = shared_model, local_model
        self.action_space, self.meta_data_util = action_space, meta_data_util
        self.config, self.constants = config, constants
        self.tensorboard = tensorboard

        # Values produced by calc_loss start out unset.
        self.entropy = self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]
        self.logger = None

        # Auxiliary objectives, each gated by its own config flag.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(local_model)
            self.action_prediction_loss = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(local_model)
            self.temporal_autoencoder_loss = None
        if config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(local_model, num_objects=67)
            self.object_detection_loss = None
        if config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(local_model)
            self.symbolic_language_prediction_loss = None
        if config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(local_model)
            self.goal_prediction_loss = None

        shared_parameters = shared_model.get_parameters()
        self.optimizer = optim.Adam(shared_parameters, lr=constants["learning_rate"])

        AbstractLearning.__init__(self, shared_model, local_model, self.calc_loss,
                                  self.optimizer, config, constants, tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the contextual-bandit loss for a batch of replay items.

        The main objective is the sum over items of
        ``reward * log_prob(chosen action)``. The returned loss is the
        negated objective minus ``self.entropy_coef`` times the summed policy
        entropy, plus the coefficient-weighted loss of every auxiliary
        objective enabled in ``self.config``.

        Side effects: sets ``self.entropy``, ``self.ratio`` and each
        ``self.*_loss`` attribute of the auxiliary objectives (``None`` when
        the objective is disabled).

        Args:
            batch_replay_items: collection of replay items exposing
                ``get_action()``, ``get_reward()`` and ``get_log_prob()``;
                it is also forwarded to the auxiliary loss calculators.

        Returns:
            The loss to minimise (a torch value; presumably a scalar, which
            depends on what the auxiliary calculators return).
        """
        # Only the fields that feed the loss are collected; the observed
        # state and factor entropy were gathered before but never used.
        action_batch = []
        immediate_rewards = []
        log_probabilities = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())
            immediate_rewards.append(replay_item.get_reward())
            log_probabilities.append(replay_item.get_log_prob())

        model_log_prob_batch = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))
        immediate_rewards = cuda_var(torch.from_numpy(np.array(immediate_rewards)).float())

        # Pick the log-probability of the action actually taken in each item
        # and weight it by the reward received for that action.
        chosen_log_probs = model_log_prob_batch.gather(1, action_batch.view(-1, 1))
        reward_log_probs = immediate_rewards * chosen_log_probs.view(-1)

        model_prob_batch = torch.exp(model_log_prob_batch)
        self.entropy = -torch.sum(torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(reward_log_probs)

        # We want the objective to increase and the entropy bonus to keep the
        # policy from collapsing, hence both enter the loss negated.
        loss = -objective - self.entropy_coef * self.entropy
        self.ratio = torch.abs(objective)/(self.entropy_coef * self.entropy)  # we want the ratio to be high

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(batch_replay_items)
            # The calculator may return None, in which case nothing is added.
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants["action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(batch_replay_items)
            self.object_detection_loss = self.constants["object_detection_coeff"] * self.object_detection_loss
            loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            self.symbolic_language_prediction_loss = self.constants["symbolic_language_prediction_coeff"] * \
                                                     self.symbolic_language_prediction_loss
            loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            self.goal_prediction_loss = self.goal_prediction_calculator.calc_loss(batch_replay_items)
            self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * self.goal_prediction_loss
            loss = loss + self.goal_prediction_loss
        else:
            self.goal_prediction_loss = None

        return loss

    @staticmethod
    def convert_text_to_indices(text, vocab, ignore_case=True):
        """Tokenize ``text`` with NLTK and map each token to its vocabulary id.

        Tokens are lower-cased before lookup when ``ignore_case`` is true.
        A token missing from ``vocab`` maps to ``vocab["$UNK$"]``. The raw
        text is also printed to stdout.

        Args:
            text: string to tokenize.
            vocab: mapping from token to index; must contain ``"$UNK$"`` if
                any token is unknown.
            ignore_case: lower-case tokens before the lookup.

        Returns:
            List with one vocabulary index per token.
        """
        print ("instruction ", text)

        words = nltk.word_tokenize(text)
        if ignore_case:
            words = [word.lower() for word in words]

        # The unknown-token entry is only looked up when it is needed.
        return [vocab[word] if word in vocab else vocab["$UNK$"] for word in words]

    @staticmethod
    def convert_indices_to_text(indices, vocab):
        return " ".join([vocab[index] for index in indices])

    def get_goal(self, metadata):

        if metadata["goal-screen"] is None:
            return None, None, None, None

        left, bottom, depth = metadata["goal-screen"]

        if 0.01 < left < self.config["image_width"] and 0.01 < bottom < self.config["image_height"] and depth > 0.01:

            scaled_left = left / float(self.config["image_width"])
            scaled_top = 1.0 - bottom / float(self.config["image_height"])

            row_real = self.config["num_manipulation_row"] * scaled_top
            col_real = self.config["num_manipulation_col"] * scaled_left
            row, col = round(row_real), round(col_real)

            if row < 0:
                row = 0
            elif row >= self.config["num_manipulation_row"]:
                row = self.config["num_manipulation_row"] - 1
            if col < 0:
                col = 0
            elif col >= self.config["num_manipulation_col"]:
                col = self.config["num_manipulation_col"] - 1

            return row, col, row_real, col_real
        else:
            return None, None, None, None

    @staticmethod
    def do_train(house_id, shared_model, config, action_space, meta_data_util,
                 constants, train_dataset, tune_dataset, experiment,
                 experiment_name, rank, server, logger, model_type, vocab, use_pushover=False):
        try:
            TmpAsynchronousContextualBandit.do_train_(house_id, shared_model, config, action_space, meta_data_util,
                                                      constants, train_dataset, tune_dataset, experiment,
                                                      experiment_name, rank, server, logger, model_type,
                                                      vocab, use_pushover)
        except Exception:
            exc_info = sys.exc_info()
            traceback.print_exception(*exc_info)

    @staticmethod
    def do_train_(house_id, shared_model, config, action_space, meta_data_util, constants,
                  train_dataset, tune_dataset, experiment, experiment_name, rank, server,
                  logger, model_type, vocab, use_pushover=False):
        """Run the contextual-bandit training loop for one client.

        Steps performed, in order:
          1. Launch the Unity build for ``house_id`` on ``config["port"]`` and
             initialise ``server``.
          2. Create a Tensorboard logger when ``rank == 0`` (``None`` otherwise).
          3. Build a local model via ``model_type(config, constants)`` and a
             ``TmpHouseAgent`` around it.
          4. If ``tune_dataset`` is non-empty, evaluate on it before training
             and again after every epoch.
          5. For every data point: copy the shared model's state into the local
             model, roll out by sampling actions until STOP is sampled or
             ``horizon + max_extra_horizon`` actions were taken, halt the
             simulator, and call ``learner.do_update`` on the collected items.
          6. Save the local model after each epoch.

        Args:
            house_id: used to pick the simulator executable and config file.
            shared_model: source of the state copied into the local model and
                of the parameters the learner optimizes.
            config: mapping; ``config["port"]`` is read here.
            action_space: provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: forwarded to the agent and the learner.
            constants: mapping; ``horizon`` and ``max_extra_horizon`` are read here.
            train_dataset: sized iterable of data points to train on.
            tune_dataset: sized collection passed to ``tmp_agent.test``.
            experiment: path prefix under which the model is saved.
            experiment_name: name passed to ``Tensorboard``.
            rank: client index; rank 0 owns the Tensorboard logger.
            server: simulator server; initialised here and used by the agent.
            logger: object with a ``log(str)`` method.
            model_type: callable ``model_type(config, constants)`` creating the local model.
            vocab: forwarded to ``tmp_agent.test``.
            use_pushover: currently has no effect; the pushover logger is
                ``None`` in both branches.
        """

        logger.log("In Training...")
        launch_k_unity_builds([config["port"]], "./house_" + str(house_id) + "_elmer.x86_64",
                              arg_str="--config ./AssetsHouse/config" + str(house_id) + ".json",
                              cwd="./simulators/house/")
        logger.log("Launched Builds.")
        server.initialize_server()
        logger.log("Server Initialized.")

        # Test policy
        test_policy = gp.get_argmax_action

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
            logger.log('Created Tensorboard Server.')
        else:
            tensorboard = None

        # NOTE(review): both branches yield None, so use_pushover is ignored.
        if use_pushover:
            pushover_logger = None
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)
        # local_model.train()

        # Create the Agent
        tmp_agent = TmpHouseAgent(server=server,
                                  model=local_model,
                                  test_policy=test_policy,
                                  action_space=action_space,
                                  meta_data_util=meta_data_util,
                                  config=config,
                                  constants=constants)
        logger.log("Created Agent.")

        # Running count of sampled actions, indexed by action id.
        action_counts = [0] * action_space.num_actions()
        # NOTE(review): hard-coded; constants["max_epochs"] is not used here.
        max_epochs = 100000 # constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        if tune_dataset_size > 0:
            # Test on tuning data
            tmp_agent.test(tune_dataset, vocab, tensorboard=tensorboard,
                           logger=logger, pushover_logger=pushover_logger)

        # Create the learner to compute the loss
        learner = TmpAsynchronousContextualBandit(shared_model, local_model, action_space, meta_data_util,
                                                  config, constants, tensorboard)
        # TODO change 2 --- unity launch moved up
        learner.logger = logger

        for epoch in range(1, max_epochs + 1):

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_model.load_from_state_dict(shared_model.get_state_dict())

                # Progress report every 100 data points; the logged value is
                # the zero-based index of the current data point.
                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %(data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" % action_counts)

                # Rollout budget: horizon plus the allowed extra horizon.
                num_actions = 0
                max_num_actions = constants["horizon"]
                max_num_actions += constants["max_extra_horizon"]

                image, metadata = tmp_agent.server.reset_receive_feedback(data_point)
                instruction = data_point.get_instruction()
                # instruction_str = TmpAsynchronousContextualBandit.convert_indices_to_text(instruction, vocab)
                # print("Instruction str is ", instruction_str)

                # Pose and Orientation gone TODO change 3
                state = AgentObservedState(instruction=instruction,
                                           config=config,
                                           constants=constants,
                                           start_image=image,
                                           previous_action=None,
                                           data_point=data_point)
                state.goal = learner.get_goal(metadata)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                # Stays True unless the policy itself samples the STOP action.
                forced_stop = True

                while num_actions < max_num_actions:

                    # logger.log("Training: Meta Data %r " % metadata)

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, state_feature = \
                        local_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    # STOP is not sent here; it is issued via the halt call
                    # after the loop, with state/log_probabilities left as
                    # they were when STOP was sampled.
                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = tmp_agent.server.send_action_receive_feedback(action)
                    # logger.log("Action is %r, Reward is %r Probability is %r " % (action, reward, probabilities))

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state, action, reward, log_prob=log_probabilities)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    # Pose and orientation gone, TODO change 4
                    state = state.update(image, action, data_point=data_point)
                    state.goal = learner.get_goal(metadata)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = tmp_agent.server.halt_and_receive_feedback()
                total_reward += reward

                # Store it in the replay memory list
                # Only a STOP chosen by the policy becomes a training item; a
                # stop forced by the action budget is not learned from.
                if not forced_stop:
                    # logger.log("Action is Stop, Reward is %r Probability is %r " % (reward, probabilities))
                    replay_item = ReplayMemoryItem(state, action_space.get_stop_action_index(),
                                                   reward, log_prob=log_probabilities)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                if len(batch_replay_items) > 0:  # 32
                    loss_val = learner.do_update(batch_replay_items)

                    # Only the client that owns a Tensorboard logger reports.
                    # NOTE(review): the `.data[0]` reads below presumably rely
                    # on an older PyTorch where reduced values are indexable
                    # -- confirm against the PyTorch version in use.
                    if tensorboard is not None:
                        # cross_entropy = float(learner.cross_entropy.data[0])
                        # tensorboard.log(cross_entropy, loss_val, 0)
                        tensorboard.log_scalar("loss", loss_val)
                        # Entropy is reported divided by (num_actions + 1).
                        entropy = float(learner.entropy.data[0])/float(num_actions + 1)
                        tensorboard.log_scalar("entropy", entropy)
                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar("Abs_objective_to_entropy_ratio", ratio)
                        tensorboard.log_scalar("total_reward", total_reward)
                        tensorboard.log_scalar("mean navigation error", metadata['mean-navigation-error'])

                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(learner.action_prediction_loss.data[0])
                            # learner.tensorboard is the same object as the
                            # local `tensorboard` passed to the learner above.
                            learner.tensorboard.log_action_prediction_loss(action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(learner.symbolic_language_prediction_loss.data[0])
                            tensorboard.log_scalar("sym_language_prediction_loss", symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss", goal_prediction_loss)

            # Save the model
            local_model.save_model(experiment + "/contextual_bandit_" + str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)

            if tune_dataset_size > 0:
                # Test on tuning data
                tmp_agent.test(tune_dataset, vocab, tensorboard=tensorboard,
                               logger=logger, pushover_logger=pushover_logger)
Пример #5
0
class TmpSupervisedLearning(AbstractLearning):
    """ Perform Supervised Learning """
    def __init__(self, shared_model, local_model, action_space, meta_data_util,
                 config, constants, tensorboard):
        """Create the supervised learner around a shared and a local model.

        Auxiliary loss calculators enabled in ``config`` are built on
        ``local_model``; the Adam optimizer is created over
        ``shared_model.get_parameters()``.

        Args:
            shared_model: model whose parameters are optimized.
            local_model: model handed to the auxiliary loss calculators.
            action_space: stored unchanged on the instance.
            meta_data_util: stored unchanged on the instance.
            config: mapping read for the ``do_*`` auxiliary-objective flags.
            constants: mapping read for ``max_epochs``,
                ``entropy_coefficient`` and ``learning_rate``.
            tensorboard: stored on the instance and passed to the base class.
        """
        self.max_epoch = constants["max_epochs"]
        self.shared_model, self.local_model = shared_model, local_model
        self.action_space, self.meta_data_util = action_space, meta_data_util
        self.config, self.constants = config, constants
        self.tensorboard = tensorboard

        # Values produced by calc_loss start out unset.
        self.entropy = self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]

        # Auxiliary objectives, each gated by its own config flag.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(local_model)
            self.action_prediction_loss = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(local_model)
            self.temporal_autoencoder_loss = None
        if config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(local_model, num_objects=67)
            self.object_detection_loss = None
        if config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(local_model)
            self.symbolic_language_prediction_loss = None
        if config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(local_model)
            self.goal_prediction_loss = None

        shared_parameters = shared_model.get_parameters()
        self.optimizer = optim.Adam(shared_parameters,
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, shared_model, local_model,
                                  self.calc_loss, self.optimizer, config,
                                  constants, tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the supervised loss for a batch of replay items.

        The main objective is the mean log-probability the model assigned to
        the recorded actions. The returned loss is the negated objective minus
        ``self.entropy_coef`` times the mean policy entropy, plus the mean
        factor entropy when ``self.local_model`` is an
        ``IncrementalModelRecurrentImplicitFactorizationResnet``, plus the
        coefficient-weighted loss of every auxiliary objective enabled in
        ``self.config``.

        Side effects: sets ``self.entropy``, ``self.ratio``,
        ``self.mean_factor_entropy`` and each ``self.*_loss`` attribute of the
        auxiliary objectives (``None`` when not applicable).

        Args:
            batch_replay_items: collection of replay items exposing
                ``get_action()``, ``get_log_prob()`` and
                ``get_factor_entropy()``; it is also forwarded to the
                auxiliary loss calculators.

        Returns:
            The loss to minimise (a torch value; presumably a scalar, which
            depends on what the auxiliary calculators return).
        """
        # The observed state of each item and the batch-mean action
        # distribution were computed before but never used, so they are no
        # longer gathered.
        action_batch = []
        log_probabilities = []
        factor_entropy = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())

        model_log_prob_batch = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))

        num_states = int(action_batch.size()[0])
        # Log-probability of the recorded action for every item.
        chosen_log_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))

        model_prob_batch = torch.exp(model_log_prob_batch)
        self.entropy = -torch.mean(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(chosen_log_probs) / num_states

        # We want the objective to increase; the entropy term enters negated
        # so that higher entropy lowers the loss.
        loss = -objective - self.entropy_coef * self.entropy
        self.ratio = torch.abs(objective) / (
            self.entropy_coef * self.entropy)  # we want the ratio to be high

        # Minimize the factor entropy if the model is the implicit
        # factorization model.
        if isinstance(self.local_model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(factor_entropy))
            loss = loss + self.mean_factor_entropy
        else:
            self.mean_factor_entropy = None

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            # The calculator may return None, in which case nothing is added.
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            self.object_detection_loss = self.constants[
                "object_detection_coeff"] * self.object_detection_loss
            loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            self.symbolic_language_prediction_loss = self.constants["symbolic_language_prediction_coeff"] * \
                                                     self.symbolic_language_prediction_loss
            loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            self.goal_prediction_loss = self.goal_prediction_calculator.calc_loss(
                batch_replay_items)
            self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * \
                                        self.goal_prediction_loss
            loss = loss + self.goal_prediction_loss
        else:
            self.goal_prediction_loss = None

        return loss

    @staticmethod
    def convert_text_to_indices(text, vocab, ignore_case=True):
        """Tokenize ``text`` with NLTK and map each token to its vocabulary id.

        Tokens are lower-cased before lookup when ``ignore_case`` is true.
        A token missing from ``vocab`` maps to ``vocab["$UNK$"]``.

        Args:
            text: string to tokenize.
            vocab: mapping from token to index; must contain ``"$UNK$"`` if
                any token is unknown.
            ignore_case: lower-case tokens before the lookup.

        Returns:
            List with one vocabulary index per token.
        """
        def lookup(token):
            # Normalise case first, then fall back to the unknown-token id.
            key = token.lower() if ignore_case else token
            return vocab[key] if key in vocab else vocab["$UNK$"]

        return [lookup(token) for token in nltk.word_tokenize(text)]

    @staticmethod
    def do_train(shared_model, config, action_space, meta_data_util, constants,
                 train_dataset, tune_dataset, experiment, experiment_name,
                 rank, server, logger, model_type, vocab, use_pushover=False):
        """Run ``do_train_`` and print, rather than propagate, any ``Exception``.

        All arguments are forwarded unchanged and in order to
        ``TmpSupervisedLearning.do_train_``. If that call raises an
        ``Exception``, its traceback is printed and this method returns
        normally.
        """
        forwarded = (shared_model, config, action_space, meta_data_util,
                     constants, train_dataset, tune_dataset, experiment,
                     experiment_name, rank, server, logger, model_type, vocab,
                     use_pushover)
        try:
            TmpSupervisedLearning.do_train_(*forwarded)
        except Exception:
            # Report the failure without letting it escape to the caller.
            traceback.print_exc()

    @staticmethod
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  vocab,
                  use_pushover=False):
        """Train by replaying every data point's stored trajectory in the simulator.

        For each epoch and each training data point, the local model is
        synchronised from ``shared_model``, the simulator is reset to the data
        point, and the first 300 actions of the data point's trajectory are
        executed one by one. For every step the model's log-probabilities are
        stored in a replay item; a final item for the STOP action is appended
        after halting. ``learner.do_update`` is then called on the episode's
        replay items. After every epoch the local model is saved and, if the
        tuning set is non-empty, the agent is tested on it.

        :param shared_model: model whose state dict is loaded into the local
            model before every data point.
        :param config: configuration; ``config["port"]`` is used to launch the
            simulator build, the rest is forwarded to model, agent and learner.
        :param action_space: provides ``num_actions()`` and
            ``get_stop_action_index()``.
        :param meta_data_util: forwarded to the agent and the learner.
        :param constants: ``constants["max_epochs"]`` sets the number of epochs.
        :param train_dataset: sequence of data points exposing
            ``get_instruction()`` and ``get_trajectory()``.
        :param tune_dataset: data set passed to ``tmp_agent.test`` after each
            epoch when it is non-empty.
        :param experiment: path prefix for the saved model files.
        :param experiment_name: name handed to ``Tensorboard``.
        :param rank: client index; only rank 0 creates a tensorboard, and the
            rank is part of the saved model file name.
        :param server: simulator server; ``initialize_server()`` is called on it.
        :param logger: object with a ``log(str)`` method.
        :param model_type: callable building the local model from
            ``(config, constants)``.
        :param vocab: forwarded to ``tmp_agent.test``.
        :param use_pushover: currently without effect -- the pushover logger is
            ``None`` in both branches below.
        """

        print("In training...")

        # Start the simulator build on the configured port, then the server.
        launch_k_unity_builds([config["port"]],
                              "./simulators/house_3_elmer.x86_64")
        server.initialize_server()
        print("launched builds")

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        # Both branches yield None: the PushoverLogger construction is
        # commented out, so use_pushover has no effect at the moment.
        if use_pushover:
            # pushover_logger = PushoverLogger(experiment_name)
            pushover_logger = None
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)
        # local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        tmp_agent = TmpHouseAgent(server=server,
                                  model=local_model,
                                  test_policy=test_policy,
                                  action_space=action_space,
                                  meta_data_util=meta_data_util,
                                  config=config,
                                  constants=constants)
        logger.log("Created Agent...")

        # Running count of how often each action index was replayed.
        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = TmpSupervisedLearning(shared_model, local_model,
                                        action_space, meta_data_util, config,
                                        constants, tensorboard)
        # TODO change 2 --- unity launch moved up

        for epoch in range(1, max_epochs + 1):

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_model.load_from_state_dict(shared_model.get_state_dict())

                # Logged before the current data point is processed, so
                # data_point_ix is the number of data points already done.
                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                # Reset the simulator to this data point and get the first image.
                image, metadata = tmp_agent.server.reset_receive_feedback(
                    data_point)
                # instruction = TmpSupervisedLearning.convert_text_to_indices(metadata["instruction"], vocab)
                instruction = data_point.get_instruction()

                # Pose and Orientation gone TODO change 3
                state = AgentObservedState(instruction=instruction,
                                           config=config,
                                           constants=constants,
                                           start_image=image,
                                           previous_action=None,
                                           data_point=data_point)

                model_state = None
                batch_replay_items = []
                # NOTE(review): total_reward is accumulated but never read.
                total_reward = 0

                # trajectory = metadata["trajectory"]
                # Only the first 300 actions of the stored trajectory are replayed.
                trajectory = data_point.get_trajectory()[0:300]

                for action in trajectory:

                    # Score the current state with the local model; the action
                    # that is executed comes from the stored trajectory.
                    log_probabilities, model_state, image_emb_seq, state_feature = \
                        local_model.get_probs(state, model_state)

                    # Count the replayed action
                    action_counts[action] += 1

                    # Send the action and get feedback
                    image, reward, metadata = tmp_agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    # Pose and orientation gone, TODO change 4
                    state = state.update(image, action, data_point=data_point)

                    total_reward += reward

                # Send final STOP action and get feedback
                # Score the final state before halting
                log_probabilities, model_state, image_emb_seq, state_feature = \
                    local_model.get_probs(state, model_state)
                image, reward, metadata = tmp_agent.server.halt_and_receive_feedback(
                )
                total_reward += reward

                # if tensorboard is not None:
                #     tensorboard.log_all_train_errors(
                #         metadata["edit_dist_error"], metadata["closest_dist_error"], metadata["stop_dist_error"])

                # Store the STOP step in the replay memory list
                replay_item = ReplayMemoryItem(
                    state,
                    action_space.get_stop_action_index(),
                    reward,
                    log_prob=log_probabilities)
                batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update. The list always holds at least the STOP item
                # appended above, so this condition is always true.
                if len(batch_replay_items) > 0:  # 32
                    loss_val = learner.do_update(batch_replay_items)
                    # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                    # del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        # cross_entropy = float(learner.cross_entropy.data[0])
                        # tensorboard.log(cross_entropy, loss_val, 0)
                        # NOTE(review): num_actions is computed but never used;
                        # the normalisations that needed it are commented out.
                        num_actions = len(trajectory) + 1
                        tensorboard.log_scalar(
                            "loss_val", loss_val)  # /float(num_actions))
                        entropy = float(
                            learner.entropy.data[0])  # /float(num_actions)
                        tensorboard.log_scalar("entropy", entropy)
                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        # Each optional loss is logged only when the learner
                        # has a value for it.
                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            # NOTE(review): uses learner.tensorboard while all
                            # sibling branches use the local tensorboard;
                            # presumably the same object -- confirm.
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)
                        if learner.mean_factor_entropy is not None:
                            mean_factor_entropy = float(
                                learner.mean_factor_entropy.data[0])
                            tensorboard.log_factor_entropy_loss(
                                mean_factor_entropy)

            # Save the local model once per epoch; rank and epoch are part of
            # the file name.
            local_model.save_model(experiment + "/contextual_bandit_" +
                                   str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)

            if tune_dataset_size > 0:
                # Test on tuning data
                print("Going for testing")
                tmp_agent.test(tune_dataset,
                               vocab,
                               tensorboard=tensorboard,
                               logger=logger,
                               pushover_logger=pushover_logger)
                print("Done testing")
Пример #6
0
class MultiClientIncrementalDAGGER(AbstractLearning):
    """ Perform DAGGER algorithm of Ross et al.

    Several simulator clients are advanced in a round-robin loop. Their replay
    items are collected in one shared list and the model is updated whenever
    that list holds more than ``_UPDATE_THRESHOLD`` items. After every epoch
    the mixing coefficient ``beta`` is recomputed and pushed to all clients,
    and the model is saved.
    """

    # An update is performed once the shared replay list holds more items
    # than this.
    _UPDATE_THRESHOLD = 32

    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Store the collaborators and build the optimizer.

        :param model: model to train; must provide ``get_parameters()``.
        :param action_space: action space handed to every client.
        :param meta_data_util: stored on the instance; not used in this class.
        :param config: configuration; provides ``num_client`` and the
            ``do_*`` switches of the auxiliary objectives.
        :param constants: provides ``max_epochs``, ``entropy_coefficient``,
            ``learning_rate`` and the auxiliary loss coefficients.
        :param tensorboard: tensorboard logger, or ``None`` to disable logging.
        """
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.num_client = config["num_client"]
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]
        self.beta = 0.9
        self.beta_exp_decay = 0.9
        logging.info(
            "DAGGER: using starting beta of %r and beta exp decay of %r",
            self.beta, self.beta_exp_decay)

        # calc_loss overwrites every one of these on each call. Creating them
        # here guarantees that the attributes exist whatever objectives are
        # enabled.
        self.mean_factor_entropy = None
        self.action_prediction_loss = None
        self.temporal_autoencoder_loss = None
        self.object_detection_loss = None
        self.symbolic_language_prediction_loss = None

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.model)
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.model)
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(
                self.model, num_objects=67)
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.model)

        self.optimizer = optim.Adam(model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

    def calc_loss(self, batch_replay_items):
        """Compute the training loss for a batch of replay items.

        The main term is the negative mean log-probability of the stored
        actions minus ``entropy_coef`` times the policy entropy. The mean
        factor entropy (implicit factorization models only) and every enabled
        auxiliary loss, scaled by its coefficient from ``constants``, are
        added on top. The individual terms are stored on ``self`` so that
        ``do_train`` can log them.

        :param batch_replay_items: list of replay items providing
            ``get_action()``, ``get_log_prob()`` and ``get_factor_entropy()``.
        :return: the total loss.
        """
        action_batch = []
        log_probabilities = []
        factor_entropy = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())  # expert action
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())

        # The log-probabilities stored at rollout time are reused; the model
        # is not evaluated again here.
        model_log_prob_batch = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))

        num_states = int(action_batch.size()[0])
        chosen_log_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))

        # NOTE(review): hard-coded four-entry reference distribution,
        # presumably the empirical action frequencies -- confirm.
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(model_log_prob_batch)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        # The cross entropy is only logged; it is not part of the loss.
        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        self.entropy = -torch.mean(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(chosen_log_probs) / num_states
        # Essentially we want the objective and the entropy to increase
        loss = -objective - self.entropy_coef * self.entropy

        # Minimize the Factor Entropy if the model is implicit factorization model
        if isinstance(self.model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(factor_entropy))
            loss = loss + self.mean_factor_entropy
        else:
            self.mean_factor_entropy = None

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            self.object_detection_loss = self.constants[
                "object_detection_coeff"] * self.object_detection_loss
            loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            self.symbolic_language_prediction_loss = self.constants["symbolic_language_prediction_coeff"] * \
                                                     self.symbolic_language_prediction_loss
            loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        return loss

    @staticmethod
    def get_oracle_length(data_point):
        """Return the length of the oracle trajectory for ``data_point``.

        The trajectory is computed by ``get_oracle_trajectory`` from the data
        point's start position to the last entry of its destination list.
        """
        start_x, start_z, start_angle = data_point.get_start_pos()
        dummy_metadata = {
            "x_pos": start_x,
            "z_pos": start_z,
            "y_angle": start_angle
        }
        goal_x, goal_z = data_point.get_destination_list()[-1]
        oracle_trajectory = get_oracle_trajectory(dummy_metadata, goal_x,
                                                  goal_z, data_point)
        return len(oracle_trajectory)

    def _log_update(self, loss_val):
        """Log the statistics of the latest update; no-op without tensorboard."""
        if self.tensorboard is None:
            return

        cross_entropy = float(self.cross_entropy.data[0])
        self.tensorboard.log(cross_entropy, loss_val, 0)
        entropy = float(self.entropy.data[0])
        self.tensorboard.log_scalar("entropy", entropy)

        # Optional terms are None when the matching objective is disabled.
        if self.action_prediction_loss is not None:
            self.tensorboard.log_action_prediction_loss(
                float(self.action_prediction_loss.data[0]))
        if self.temporal_autoencoder_loss is not None:
            self.tensorboard.log_temporal_autoencoder_loss(
                float(self.temporal_autoencoder_loss.data[0]))
        if self.object_detection_loss is not None:
            self.tensorboard.log_object_detection_loss(
                float(self.object_detection_loss.data[0]))
        if self.symbolic_language_prediction_loss is not None:
            self.tensorboard.log_scalar(
                "sym_language_prediction_loss",
                float(self.symbolic_language_prediction_loss.data[0]))
        if self.mean_factor_entropy is not None:
            self.tensorboard.log_factor_entropy_loss(
                float(self.mean_factor_entropy.data[0]))

    def do_train(self, agent, train_dataset, tune_dataset, experiment_name):
        """ Perform training

        :param agent: agent handed to every client and used to test on the
            tuning data after each epoch except the last.
        :param train_dataset: training data, iterated through
            ``DatasetIterator``.
        :param tune_dataset: data set passed to ``agent.test``.
        :param experiment_name: directory prefix for the saved models.
        :raises AssertionError: if a client reports an unknown status, or if
            the data set iterator still has examples when every client is
            waiting for one.
        """

        # All clients append their replay items to this one shared list.
        batch_replay_items = []
        clients = [
            Client(agent, self.config, self.constants, self.action_space,
                   self.tensorboard, client_ix, batch_replay_items, self.beta)
            for client_ix in range(self.num_client)
        ]

        dataset_iterator = DatasetIterator(train_dataset)
        epoch = 1

        if epoch <= self.max_epoch:
            logging.info("Starting epoch %d", epoch)

        while True:

            for client in clients:

                # See if the client can progress
                client_status = client.try_to_progress()
                if client_status == Client.WAITING_FOR_EXAMPLE:
                    # Provide the next example
                    data_point = dataset_iterator.get_next()
                    if data_point is None:
                        continue
                    # Horizon: oracle trajectory length plus a fixed slack
                    max_num_actions = self.get_oracle_length(data_point)
                    max_num_actions += self.constants["max_extra_horizon"]

                    if self.tensorboard is not None:
                        self.tensorboard.log_scalar("total_reward",
                                                    client.total_reward)
                    client.accept_new_example(data_point, max_num_actions)

                elif client_status == Client.WAITING_FOR_ACTION:

                    # Generate probabilities over actions and take action
                    log_probabilities, new_model_state, image_emb_seq = self.model.get_probs(
                        client.get_state(), client.get_model_state())
                    if isinstance(
                            self.model,
                            IncrementalModelRecurrentImplicitFactorizationResnet
                    ):
                        factor_entropy = self.model.get_recent_factorization_entropy(
                        )
                    else:
                        factor_entropy = None
                    client.take_action(log_probabilities, new_model_state,
                                       image_emb_seq, factor_entropy)

                elif client_status == Client.WAITING_TO_RECEIVE:
                    pass
                else:
                    raise AssertionError("Unknown status. Found " +
                                         str(client_status))

            # Perform update
            if len(batch_replay_items) > self._UPDATE_THRESHOLD:
                loss_val = self.do_update(batch_replay_items)
                del batch_replay_items[:]  # in place list clear
                self._log_update(loss_val)

            # Check if an epoch is finished. An epoch is over if all clients are waiting
            # for an example (at which point the iterator also returns none)
            epoch_completed = all(
                client.get_status() == Client.WAITING_FOR_EXAMPLE
                for client in clients)
            if epoch_completed:
                # Explicit check instead of an assert statement, so that it is
                # not removed when Python runs with -O.
                if dataset_iterator.get_next() is not None:
                    raise AssertionError(
                        "Dataset iterator not exhausted at the end of an epoch")

                # Reset the iterator
                dataset_iterator.reset()

                # Attenuate the dagger beta value
                # NOTE(review): after epoch 1 this gives 0.9 ** 1, i.e. the
                # starting beta again -- confirm the intended schedule.
                self.beta = math.pow(self.beta_exp_decay, epoch)
                for client in clients:
                    client.update_dagger_beta(self.beta)
                logging.info("Attenuated the beta to %r after epoch %r",
                             self.beta, epoch)

                # Save the model
                self.model.save_model(experiment_name + "/dagger_epoch_" +
                                      str(epoch))
                if epoch >= self.max_epoch:
                    break
                epoch += 1
                logging.info("Starting epoch %d", epoch)

                # Test on tuning data
                agent.test(tune_dataset, tensorboard=self.tensorboard)
Пример #7
0
class MultiClientIncrementalContextualBanditGoalImage(AbstractLearning):
    """ Perform Contextual Bandit learning (Kakade and Langford (circa 2006) & Misra, Langford and Artzi EMNLP 2017) """
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Store the collaborators, set up auxiliary losses and the optimizer.

        :param model: model to train; must provide ``get_parameters()``.
        :param action_space: action space of the agent.
        :param meta_data_util: stored on the instance.
        :param config: provides ``num_client`` and the ``do_*`` switches of
            the auxiliary objectives.
        :param constants: provides ``max_epochs``, ``entropy_coefficient``
            and ``learning_rate``.
        :param tensorboard: tensorboard logger stored on the instance.
        """
        self.max_epoch = constants["max_epochs"]

        # Collaborators handed in by the caller
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.num_client = config["num_client"]
        self.tensorboard = tensorboard

        # Loss statistics, empty until a loss has been computed
        self.entropy = None
        self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]

        # Auxiliary objectives: a calculator and its loss slot are created
        # only for the objectives switched on in the configuration.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(model)
            self.action_prediction_loss = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                model)
            self.temporal_autoencoder_loss = None
        if config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(model)
            self.object_detection_loss = None

        parameters = model.get_parameters()
        self.optimizer = optim.Adam(parameters,
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, model, self.calc_loss, self.optimizer,
                                  config, constants)

    def calc_loss(self, batch_replay_items):
        """Compute the contextual-bandit loss for a batch of replay items.

        The main term is the negative mean of reward times log-probability of
        the stored action. Added to it are ``entropy_coef`` times the cross
        entropy between a fixed reference distribution and the batch-mean
        action distribution, the mean factor entropy (implicit factorization
        models only), and every enabled auxiliary loss scaled by its
        coefficient from ``constants``. The individual terms are stored on
        ``self``.

        :param batch_replay_items: list of replay items providing
            ``get_agent_observed_state()``, ``get_action()``, ``get_reward()``,
            ``get_log_prob()`` and ``get_factor_entropy()``.
        :return: the total loss.
        """
        observed_states = []
        rewards = []
        actions = []
        log_prob_rows = []
        factor_entropies = []
        for item in batch_replay_items:
            observed_states.append(item.get_agent_observed_state())
            actions.append(item.get_action())
            rewards.append(item.get_reward())
            log_prob_rows.append(item.get_log_prob())
            factor_entropies.append(item.get_factor_entropy())

        # The log-probabilities stored with the replay items are reused; the
        # model is not evaluated again here.
        log_prob_batch = torch.cat(log_prob_rows)
        actions = cuda_var(torch.from_numpy(np.array(actions)))
        rewards = cuda_var(torch.from_numpy(np.array(rewards)).float())

        batch_size = int(actions.size()[0])
        picked_log_probs = log_prob_batch.gather(1, actions.view(-1, 1))
        weighted_log_probs = rewards * picked_log_probs.view(-1)

        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        mean_action_probs = torch.mean(torch.exp(log_prob_batch), 0)
        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mean_action_probs))

        # The objective should go up and the cross entropy should go down.
        objective = torch.sum(weighted_log_probs) / batch_size
        loss = -objective + self.entropy_coef * self.cross_entropy

        # Factor entropy is minimized for implicit factorization models only.
        self.mean_factor_entropy = None
        if isinstance(self.model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(factor_entropies))
            loss = loss + self.mean_factor_entropy

        # Auxiliary losses: each attribute ends up None unless its objective
        # is enabled and produced a value.
        self.action_prediction_loss = None
        if self.config["do_action_prediction"]:
            raw_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if raw_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * raw_loss
                loss = loss + self.action_prediction_loss

        self.temporal_autoencoder_loss = None
        if self.config["do_temporal_autoencoding"]:
            raw_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if raw_loss is not None:
                self.temporal_autoencoder_loss = self.constants[
                    "temporal_autoencoder_coeff"] * raw_loss
                loss = loss + self.temporal_autoencoder_loss

        self.object_detection_loss = None
        if self.config["do_object_detection"]:
            raw_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            self.object_detection_loss = self.constants[
                "object_detection_coeff"] * raw_loss
            loss = loss + self.object_detection_loss

        return loss

    @staticmethod
    def read_goal_images(dataset, tag):
        """Load one goal image per data point of ``dataset``.

        Image ``i`` is read from
        ``goal_images/<tag>_images/final_image_<i>.png`` and its axes are
        swapped so that the last axis comes first (for a height x width x
        channel array this gives channel x height x width -- assumes the
        reader returns a 3-D array, TODO confirm).

        :param dataset: sized collection; only its length is used.
        :param tag: data set name used in the directory name, e.g. "train".
        :return: list of arrays, one per index in ``range(len(dataset))``.
        """
        path_prefix = "goal_images/" + str(tag) + "_images/final_image_"
        return [
            scipy.misc.imread(path_prefix + str(image_ix) + ".png")
            .swapaxes(1, 2).swapaxes(0, 1)
            for image_ix in range(len(dataset))
        ]

    def do_train(self, agent, train_dataset, tune_dataset, experiment_name):
        """ Perform training """

        clients = []
        batch_replay_items = []
        for client_ix in range(0, self.num_client):
            client = Client(agent, self.config, self.constants,
                            self.tensorboard, client_ix, batch_replay_items)
            clients.append(client)

        dataset_iterator = DatasetIterator(train_dataset)
        epoch = 1
        action_counts = [0] * self.action_space.num_actions()

        print("Reading images")
        start = time.time()
        train_images = self.read_goal_images(train_dataset, "train")
        tune_images = self.read_goal_images(tune_dataset, "tune")
        end = time.time()
        print("Read all images. Time taken " + str(end - start) + " seconds. ")

        if epoch <= self.max_epoch:
            logging.info("Starting epoch %d", epoch)
            # Test on tuning data
            agent.test(tune_dataset, tune_images, tensorboard=self.tensorboard)

        probabilities_batch = [None] * self.num_client
        client_state = [None] * self.num_client

        while True:

            for client_ix in range(0, self.num_client):

                client = clients[client_ix]

                # See if the client can progress
                client_status = client.try_to_progress()
                if client_status == Client.WAITING_FOR_EXAMPLE:
                    # Provide the next example
                    data_point = dataset_iterator.get_next()
                    if data_point is None:
                        continue
                    max_num_actions = len(data_point.get_trajectory())
                    max_num_actions += self.constants["max_extra_horizon"]
                    # max_num_actions = self.constants["horizon"]
                    goal_image = train_images[dataset_iterator.datapoint_ix -
                                              1]
                    client.accept_new_example(data_point, max_num_actions,
                                              goal_image)

                elif client_status == Client.WAITING_FOR_ACTION:

                    # Generate probabilities over actions and take action
                    log_probabilities, new_model_state, image_emb_seq = self.model.get_probs(
                        client.get_state(), client.get_model_state())
                    if isinstance(
                            self.model,
                            IncrementalModelRecurrentImplicitFactorizationResnet
                    ):
                        factor_entropy = self.model.get_recent_factorization_entropy(
                        )
                    else:
                        factor_entropy = None
                    client.take_action(log_probabilities, new_model_state,
                                       image_emb_seq, factor_entropy)
                    # if client_state[client_ix] is None:
                    #     # This client has not waited so make it wait for 1 iteration
                    #     # Take its state and compute the probabiltiy at the end.
                    #     client_state[client_ix] = client.get_state()
                    # else:
                    #     # This client has waited so its probability must be ready.
                    #     probabilities = probabilities_batch[client_ix]
                    #     # Generate probabilities over actions and take action
                    #     # probabilities = list(torch.exp(self.model.get_probs(client.get_state()).data))
                    #     client.take_action(probabilities)
                    #     probabilities_batch[client_ix] = None
                    #     client_state[client_ix] = None

                elif client_status == Client.WAITING_TO_RECEIVE:
                    pass
                else:
                    raise AssertionError("Unknown status. Found " +
                                         str(client_status))

            # states = [state for state in client_state if state is not None]
            # if len(states) > 0:
            #     probabilities = list(torch.exp(self.model.get_probs_batch(states).data))
            #     assert len(states) == len(probabilities)
            #     ctr = 0
            #     for i in range(0, self.num_client):
            #         if client_state[i] is not None:
            #             probabilities_batch[i] = probabilities[ctr]
            #             ctr += 1
            #         else:
            #             probabilities_batch[i] = None

            # Perform update
            if len(batch_replay_items) > 32:
                loss_val = self.do_update(batch_replay_items)
                # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                del batch_replay_items[:]  # in place list clear
                # entropy_val = float(self.entropy.data[0])
                # self.tensorboard.log(entropy_val, loss_val, total_reward)
                cross_entropy = float(self.cross_entropy.data[0])
                self.tensorboard.log(cross_entropy, loss_val, 0)
                if self.action_prediction_loss is not None:
                    action_prediction_loss = float(
                        self.action_prediction_loss.data[0])
                    self.tensorboard.log_action_prediction_loss(
                        action_prediction_loss)
                if self.temporal_autoencoder_loss is not None:
                    temporal_autoencoder_loss = float(
                        self.temporal_autoencoder_loss.data[0])
                    self.tensorboard.log_temporal_autoencoder_loss(
                        temporal_autoencoder_loss)
                if self.object_detection_loss is not None:
                    object_detection_loss = float(
                        self.object_detection_loss.data[0])
                    self.tensorboard.log_object_detection_loss(
                        object_detection_loss)
                if self.mean_factor_entropy is not None:
                    mean_factor_entropy = float(
                        self.mean_factor_entropy.data[0])
                    self.tensorboard.log_factor_entropy_loss(
                        mean_factor_entropy)

            # Check if an epoch is finished. An epoch is over if all clients are waiting
            # for an example (at which point the iterator also returns none)
            epoch_completed = all([
                client.get_status() == Client.WAITING_FOR_EXAMPLE
                for client in clients
            ])
            if epoch_completed:
                assert dataset_iterator.get_next() is None

                # Reset the iterator
                dataset_iterator.reset()

                # Save the model
                self.model.save_model(experiment_name +
                                      "/contextual_bandit_resnet_epoch_" +
                                      str(epoch))
                if epoch >= self.max_epoch:
                    break
                epoch += 1
                logging.info("Starting epoch %d", epoch)

                # Test on tuning data
                agent.test(tune_dataset,
                           tune_images,
                           tensorboard=self.tensorboard)
class AsynchronousContextualBandit(AbstractLearning):
    """ Perform expected reward maximization """
    def __init__(self, shared_model, local_model, action_space, meta_data_util,
                 config, constants, tensorboard):
        """Set up the learner, its auxiliary loss calculators and the optimizer.

        Args:
            shared_model: Model whose parameters the Adam optimizer updates.
            local_model: Model handed to the auxiliary loss calculators and
                inspected by ``calc_loss``.
            action_space: Action space; used by ``get_all_rewards``.
            meta_data_util: Stored on the instance; not used in this method.
            config: Configuration dict (reads "num_client" and the
                "do_*" auxiliary-objective flags).
            constants: Constants dict (reads "max_epochs",
                "entropy_coefficient" and "learning_rate").
            tensorboard: Tensorboard logger, or None.
        """
        self.max_epoch = constants["max_epochs"]
        self.shared_model = shared_model
        self.local_model = local_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.num_client = config["num_client"]
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.entropy_coef = constants["entropy_coefficient"]

        # The training loop inspects every one of these attributes after an
        # update, whether or not the matching objective is enabled, so they
        # must always exist. Previously some were only created when their
        # config flag was on, which led to AttributeError during logging.
        self.action_prediction_loss = None
        self.temporal_autoencoder_loss = None
        self.object_detection_loss = None
        self.symbolic_language_prediction_loss = None
        self.goal_prediction_loss = None
        self.mean_factor_entropy = None

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_model)
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_model)
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(
                self.local_model, num_objects=67)
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_model)
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(self.local_model)

        self.optimizer = optim.Adam(shared_model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_model, self.local_model,
                                  self.calc_loss, self.optimizer, self.config,
                                  self.constants, self.tensorboard)

    def calc_loss(self, batch_replay_items):
        """Return the training loss for a batch of replay items.

        The main term is the negated expected reward: per-action rewards
        weighted by the model's action probabilities, summed and divided by
        the batch size. Added to it is ``entropy_coef`` times the cross
        entropy between a fixed reference action distribution and the
        batch-averaged action distribution. The mean factor entropy (for the
        implicit-factorization model) and any enabled auxiliary losses are
        added on top.

        As a side effect this refreshes ``self.cross_entropy``,
        ``self.entropy``, ``self.mean_factor_entropy`` and the auxiliary
        loss attributes so callers can log them afterwards.

        Args:
            batch_replay_items: Replay items exposing ``get_log_prob()``,
                ``get_factor_entropy()`` and ``get_all_rewards()``.

        Returns:
            The loss tensor to minimise.
        """
        all_rewards = []
        log_probabilities = []
        factor_entropy = []
        for replay_item in batch_replay_items:
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())
            all_rewards.append(replay_item.get_all_rewards())

        all_rewards = cuda_var(
            torch.from_numpy(np.array(all_rewards)).float())  # batch x action
        log_probabilities = torch.cat(log_probabilities)

        num_states = int(all_rewards.size()[0])
        model_log_prob_batch = log_probabilities
        model_prob_batch = torch.exp(model_log_prob_batch)
        reward_probs = all_rewards * model_prob_batch
        objective = torch.sum(reward_probs) / num_states

        # Fixed 4-entry reference distribution over actions.
        # NOTE(review): the provenance of these numbers is not visible here,
        # and the length presumably has to match the number of actions.
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)
        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        # Entropy is tracked for logging only; it is not part of the loss.
        self.entropy = -torch.mean(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))

        # Essentially we want the objective to increase and cross entropy to decrease
        loss = -objective + self.entropy_coef * self.cross_entropy

        # Minimize the Factor Entropy if the model is implicit factorization model
        if isinstance(self.local_model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(factor_entropy))
            loss = loss + self.mean_factor_entropy
        else:
            self.mean_factor_entropy = None

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            # Guarded like the other auxiliary losses: a calculator that has
            # nothing to report (None) must not break the whole update.
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        return loss

    def get_all_rewards(self, metadata):
        rewards = []
        for i in range(0, self.config["num_actions"]):
            reward = metadata["reward_dict"][self.action_space.get_action_name(
                i)]
            rewards.append(reward)
        return rewards

    @staticmethod
    def do_train(shared_model,
                 config,
                 action_space,
                 meta_data_util,
                 args,
                 constants,
                 train_dataset,
                 tune_dataset,
                 experiment,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 model_type,
                 use_pushover=False):
        try:
            AsynchronousContextualBandit.do_train_(
                shared_model, config, action_space, meta_data_util, args,
                constants, train_dataset, tune_dataset, experiment,
                experiment_name, rank, server, logger, model_type,
                use_pushover)
        except Exception:
            exc_info = sys.exc_info()
            traceback.print_exception(*exc_info)

    @staticmethod
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  args,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):
        """Run the training loop of one asynchronous worker.

        For every epoch the worker evaluates on the tuning data (when there
        is any) and then, for each training example, copies the shared
        model's state into a local model, rolls out a sampled trajectory
        through the server and calls ``learner.do_update`` on the collected
        replay items. The local model is saved at the end of every epoch.

        Args:
            shared_model: Model shared across workers; the local model is
                re-synchronised from its state dict before every example.
            config: Configuration dict (reads "port" here).
            action_space: Provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: Forwarded to the agent and the learner.
            args: Forwarded to ``model_type`` when building the local model.
            constants: Constants dict (reads "max_epochs" and "horizon").
            train_dataset: Sized iterable of training data points.
            tune_dataset: Sized collection of tuning data points; evaluation
                is skipped when it is empty.
            experiment: Path prefix under which epoch checkpoints are saved.
            experiment_name: Name passed to the tensorboard and pushover
                loggers.
            rank: Worker index; only rank 0 creates a tensorboard logger.
            server: Simulator server used for the rollouts.
            logger: Object exposing ``log(message)``.
            model_type: Callable building the local model as
                ``model_type(args, config=config)``.
            use_pushover: When true, a ``PushoverLogger`` is created and
                passed to evaluation.
        """

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(args, config=config)
        if torch.cuda.is_available():
            local_model.cuda()
        local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = AsynchronousContextualBandit(shared_model, local_model,
                                               action_space, meta_data_util,
                                               config, constants, tensorboard)

        def optional_scalar(attribute_name):
            # Some diagnostics are only populated for certain configurations
            # (and some may never be set by the learner at all). A missing or
            # None attribute is reported as None so that logging is skipped
            # instead of aborting the worker with an AttributeError.
            value = getattr(learner, attribute_name, None)
            if value is None:
                return None
            return float(value.data[0])

        # Launch unity
        # NOTE(review): the build path is hard-coded to one developer's
        # machine; consider moving it into the config.
        launch_k_unity_builds([
            config["port"]
        ], "/home/dipendra/Downloads/NavDroneLinuxBuild/NavDroneLinuxBuild.x86_64"
                              )

        for epoch in range(1, max_epochs + 1):

            if tune_dataset_size > 0:
                # Test on tuning data
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                local_model.load_from_state_dict(shared_model.get_state_dict())

                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix,
                                 dataset_size)
                    logging.info("Training data action counts %r",
                                 action_counts)

                num_actions = 0
                max_num_actions = constants["horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                pose = int(metadata["y_angle"] / 15.0)
                position_orientation = (metadata["x_pos"], metadata["z_pos"],
                                        metadata["y_angle"])
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    pose=pose,
                    position_orientation=position_orientation,
                    data_point=data_point)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                # Stays True when the horizon runs out before the policy
                # samples the stop action by itself.
                forced_stop = True

                while num_actions < max_num_actions:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, state_feature = \
                        local_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
                    rewards = learner.get_all_rewards(metadata)
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   all_rewards=rewards)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    pose = int(metadata["y_angle"] / 15.0)
                    position_orientation = (metadata["x_pos"],
                                            metadata["z_pos"],
                                            metadata["y_angle"])
                    state = state.update(
                        image,
                        action,
                        pose=pose,
                        position_orientation=position_orientation,
                        data_point=data_point)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                rewards = learner.get_all_rewards(metadata)
                total_reward += reward

                if tensorboard is not None:
                    tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

                # Store the stop step only when the policy chose to stop;
                # in that case state and log_probabilities still belong to
                # the step at which stop was sampled.
                if not forced_stop:
                    replay_item = ReplayMemoryItem(
                        state,
                        action_space.get_stop_action_index(),
                        reward,
                        log_prob=log_probabilities,
                        all_rewards=rewards)
                    batch_replay_items.append(replay_item)

                # Perform update
                if len(batch_replay_items) > 0:
                    loss_val = learner.do_update(batch_replay_items)
                    del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        entropy = float(learner.entropy.data[0])
                        tensorboard.log_scalar("entropy", entropy)

                        ratio = optional_scalar("ratio")
                        if ratio is not None:
                            tensorboard.log_scalar(
                                "Abs_objective_to_entropy_ratio", ratio)

                        action_prediction_loss = optional_scalar(
                            "action_prediction_loss")
                        if action_prediction_loss is not None:
                            tensorboard.log_action_prediction_loss(
                                action_prediction_loss)

                        temporal_autoencoder_loss = optional_scalar(
                            "temporal_autoencoder_loss")
                        if temporal_autoencoder_loss is not None:
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)

                        object_detection_loss = optional_scalar(
                            "object_detection_loss")
                        if object_detection_loss is not None:
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)

                        symbolic_language_prediction_loss = optional_scalar(
                            "symbolic_language_prediction_loss")
                        if symbolic_language_prediction_loss is not None:
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)

                        goal_prediction_loss = optional_scalar(
                            "goal_prediction_loss")
                        if goal_prediction_loss is not None:
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)

                        mean_factor_entropy = optional_scalar(
                            "mean_factor_entropy")
                        if mean_factor_entropy is not None:
                            tensorboard.log_factor_entropy_loss(
                                mean_factor_entropy)

            # Save the model
            local_model.save_model(experiment + "/contextual_bandit_" +
                                   str(rank) + "_epoch_" + str(epoch))

            logging.info("Training data action counts %r", action_counts)
Пример #9
0
class SupervisedLearningFromDisk(AbstractLearning):
    """ Perform maximum likelihood on oracle trajectories using images stored on disk
    and hence does not need client or server. """
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        """Store the collaborators, build enabled auxiliary losses and the optimizer.

        ``config`` supplies the "do_*" flags that switch the auxiliary
        objectives on; ``constants`` supplies "max_epochs",
        "entropy_coefficient" and "learning_rate". The Adam optimizer is
        built over ``model.get_parameters()``.
        """
        # Collaborators handed in by the caller
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.tensorboard = tensorboard
        self.config = config
        self.constants = constants

        # Values read from the constants
        self.max_epoch = constants["max_epochs"]
        self.entropy_coef = constants["entropy_coefficient"]

        # Diagnostics refreshed by calc_loss
        self.entropy = None
        self.cross_entropy = None

        # Auxiliary objectives: each enabled flag gets a calculator plus a
        # slot for its most recent loss value.
        if config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(model)
            self.action_prediction_loss = None

        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                model)
            self.temporal_autoencoder_loss = None

        if config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectDetection(model)
            self.object_detection_loss = None

        if config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                model)
            self.symbolic_language_prediction_loss = None

        learning_rate = constants["learning_rate"]
        self.optimizer = optim.Adam(model.get_parameters(), lr=learning_rate)
        AbstractLearning.__init__(self, model, self.calc_loss, self.optimizer,
                                  config, constants, tensorboard)

    def calc_loss(self, batch_replay_items):
        """Return the maximum-likelihood loss for a batch of replay items.

        The main term is the negated mean log-probability that the model
        assigned to the action stored in each replay item. The mean factor
        entropy (for the implicit-factorization model) and any enabled
        auxiliary losses are added on top.

        As a side effect this refreshes ``self.cross_entropy``,
        ``self.entropy``, ``self.mean_factor_entropy`` and the auxiliary
        loss attributes; cross entropy and entropy are tracked for
        inspection only and do not contribute to the returned loss.

        Args:
            batch_replay_items: Replay items exposing ``get_action()``,
                ``get_log_prob()`` and ``get_factor_entropy()``.

        Returns:
            The loss tensor to minimise.
        """
        action_batch = []
        log_probabilities = []
        factor_entropy = []
        for replay_item in batch_replay_items:
            action_batch.append(replay_item.get_action())
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())

        log_probabilities = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))

        num_states = int(action_batch.size()[0])
        model_log_prob_batch = log_probabilities
        # Pick, per row, the log-probability of the action that was taken.
        chosen_log_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))

        # Fixed 4-entry reference distribution over actions.
        # NOTE(review): the provenance of these numbers is not visible here,
        # and the length presumably has to match the number of actions.
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(model_log_prob_batch)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        self.entropy = -torch.mean(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(chosen_log_probs) / num_states
        # We want the objective (mean log-likelihood) to increase.
        loss = -objective

        # Minimize the Factor Entropy if the model is implicit factorization model
        if isinstance(self.model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(factor_entropy))
            loss = loss + self.mean_factor_entropy
        else:
            self.mean_factor_entropy = None

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            # Guarded like the losses above: a calculator that has nothing
            # to report (None) must not break the whole update.
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            if self.symbolic_language_prediction_loss is not None:
                self.symbolic_language_prediction_loss = \
                    self.constants["symbolic_language_prediction_coeff"] * \
                    self.symbolic_language_prediction_loss
                loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        return loss

    @staticmethod
    def parse(folder_name):

        start = time.time()
        dataset = []
        num_examples = len(os.listdir(folder_name))
        for i in range(0, num_examples):
            example_folder_name = folder_name + "/example_" + str(i)
            image_names = [
                file for file in os.listdir(example_folder_name)
                if file.endswith('.png')
            ]
            num_actions = len(image_names)
            images = []
            for j in range(0, num_actions):
                img = scipy.misc.imread(example_folder_name +
                                        "/image_" + str(j) + ".png").swapaxes(
                                            1, 2).swapaxes(0, 1)
                images.append(img)
            dataset.append(images)
        end = time.time()
        logging.info("Parsed dataset of size %r in time % seconds",
                     len(dataset), (end - start))
        return dataset

    def calc_log_prob(self, tune_dataset, tune_image, tensorboard):

        total_validation_log_probability = 0
        for data_point_ix, data_point in enumerate(tune_dataset):
            tune_image_example = tune_image[data_point_ix]
            image = tune_image_example[0]

            model_state = None
            state = AgentObservedState(instruction=data_point.instruction,
                                       config=self.config,
                                       constants=self.constants,
                                       start_image=image,
                                       previous_action=None,
                                       pose=None,
                                       position_orientation=None,
                                       data_point=data_point)
            trajectory = data_point.get_trajectory()

            validation_log_probability = 0

            for action_ix, action in enumerate(trajectory):
                log_probabilities, model_state, image_emb_seq = self.model.get_probs(
                    state, model_state)
                validation_log_probability += float(
                    log_probabilities.data[0][action])
                image = tune_image_example[action_ix + 1]
                state = state.update(image,
                                     action,
                                     pose=None,
                                     position_orientation=None,
                                     data_point=data_point)

            log_probabilities, model_state, image_emb_seq = self.model.get_probs(
                state, model_state)
            validation_log_probability += float(log_probabilities.data[0][
                self.action_space.get_stop_action_index()])
            mean_validation_log_probability = validation_log_probability / float(
                len(trajectory) + 1)
            tensorboard.log_scalar("Validation_Log_Prob",
                                   mean_validation_log_probability)
            total_validation_log_probability += mean_validation_log_probability
        total_validation_log_probability /= float(max(len(tune_dataset), 1))
        logging.info("Mean Validation Log Prob is %r",
                     total_validation_log_probability)

    def do_train(self, train_dataset, train_images, tune_dataset, tune_images,
                 experiment_name):
        """Train the model on stored trajectories for ``self.max_epoch`` epochs.

        Every epoch starts by evaluating the tuning data through
        ``self.calc_log_prob``. Each training example is then replayed: the
        model is queried at every step of the example's stored trajectory,
        one ``ReplayMemoryItem`` is recorded per trajectory action plus a
        final one for the stop action, and ``self.do_update`` is run on those
        items. The resulting losses are written to ``self.tensorboard`` and
        the model is saved once at the end of the epoch.

        Args:
            train_dataset: Training data points; must support ``len()`` and
                iteration.
            train_images: Indexed by the position of a data point in
                ``train_dataset``. Entry 0 of each element is the start
                image; entry ``k + 1`` is the image used to update the state
                after the k-th trajectory action.
            tune_dataset: Passed unchanged to ``self.calc_log_prob``.
            tune_images: Passed unchanged to ``self.calc_log_prob``.
            experiment_name: Prefix of the path given to
                ``self.model.save_model``.
        """

        num_examples = len(train_dataset)

        # Optional losses in reporting order: the attribute that holds the
        # loss and the tensorboard call that records it. The lambdas defer
        # every tensorboard lookup until a loss is actually reported.
        optional_loss_reporters = (
            ("action_prediction_loss",
             lambda value: self.tensorboard.log_action_prediction_loss(value)),
            ("temporal_autoencoder_loss",
             lambda value: self.tensorboard.log_temporal_autoencoder_loss(
                 value)),
            ("object_detection_loss",
             lambda value: self.tensorboard.log_object_detection_loss(value)),
            ("symbolic_language_prediction_loss",
             lambda value: self.tensorboard.log_scalar(
                 "sym_language_prediction_loss", value)),
            ("mean_factor_entropy",
             lambda value: self.tensorboard.log_factor_entropy_loss(value)),
        )

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)

            # Evaluate on the tuning data before this epoch's updates
            self.calc_log_prob(tune_dataset, tune_images,
                               tensorboard=self.tensorboard)

            replay_batch = []
            pending_episodes = 0

            for example_ix, example in enumerate(train_dataset):

                # Progress message on every 100th example
                if example_ix % 100 == 99:
                    logging.info("Done %d out of %d", example_ix, num_examples)

                example_images = train_images[example_ix]
                start_image = example_images[0]
                symbolic_form = nav_drone_symbolic_instructions.get_nav_drone_symbolic_instruction_segment(
                    example)

                model_state = None
                state = AgentObservedState(
                    instruction=example.instruction,
                    config=self.config,
                    constants=self.constants,
                    start_image=start_image,
                    previous_action=None,
                    pose=None,
                    position_orientation=None,
                    data_point=example)

                # Follow the stored trajectory; next_image_ix is the index of
                # the image observed after taking the current action.
                for next_image_ix, action in enumerate(
                        example.get_trajectory(), start=1):

                    # Log-probabilities over actions in the current state
                    log_probabilities, model_state, image_emb_seq = self.model.get_probs(
                        state, model_state)

                    next_image = example_images[next_image_ix]

                    # Remember this step (reward fixed to 0)
                    replay_batch.append(
                        ReplayMemoryItem(state,
                                         action,
                                         0,
                                         log_prob=log_probabilities,
                                         symbolic_text=symbolic_form,
                                         image_emb_seq=image_emb_seq,
                                         text_emb=model_state[0]))

                    # Advance the agent state with the stored action
                    state = state.update(next_image,
                                         action,
                                         pose=None,
                                         position_orientation=None,
                                         data_point=example)

                # One more query for the final state, recorded with the stop
                # action
                log_probabilities, model_state, image_emb_seq = self.model.get_probs(
                    state, model_state)
                replay_batch.append(
                    ReplayMemoryItem(state,
                                     self.action_space.get_stop_action_index(),
                                     0,
                                     log_prob=log_probabilities,
                                     symbolic_text=symbolic_form,
                                     image_emb_seq=image_emb_seq,
                                     text_emb=model_state[0]))

                # The counter reaches 1 after every episode, so an update is
                # performed once per episode.
                pending_episodes += 1
                if pending_episodes != 1:
                    continue
                pending_episodes = 0

                loss_val = self.do_update(replay_batch)
                del replay_batch[:]  # empty the list in place
                self.tensorboard.log_scalar("loss", loss_val)

                cross_entropy_value = float(self.cross_entropy.data[0])
                self.tensorboard.log_scalar("cross_entropy",
                                            cross_entropy_value)
                entropy_value = float(self.entropy.data[0])
                self.tensorboard.log_scalar("entropy", entropy_value)

                # Report whichever optional losses are currently set
                for loss_attr, report in optional_loss_reporters:
                    optional_loss = getattr(self, loss_attr)
                    if optional_loss is not None:
                        report(float(optional_loss.data[0]))

            # Save the model at the end of the epoch
            self.model.save_model(
                experiment_name + "/contextual_bandit_resnet_epoch_" + str(epoch))