    def __init__(self, shared_navigator_model, local_navigator_model,
                 shared_predictor_model, local_predictor_model, action_space,
                 meta_data_util, config, constants, tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.shared_navigator_model = shared_navigator_model
        self.local_navigator_model = local_navigator_model
        self.shared_predictor_model = shared_predictor_model
        self.local_predictor_model = local_predictor_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.ratio = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        self.image_channels, self.image_height, self.image_width = \
            shared_navigator_model.image_module.get_final_dimension()

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_navigator_model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_navigator_model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.local_navigator_model,
                num_objects=67,
                camera_angle=60,
                image_height=self.image_height,
                image_width=self.image_width,
                object_height=0)  # -2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_navigator_model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.local_navigator_model, self.image_height,
                self.image_width)
            self.goal_prediction_loss = None

        parameters = self.shared_navigator_model.get_parameters()
        parameters.extend(self.shared_predictor_model.get_parameters())

        self.optimizer = optim.Adam(parameters, lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_navigator_model,
                                  self.local_navigator_model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

    def __init__(self, model, action_space, meta_data_util, config, constants, tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.epoch = 0
        self.global_id = 0
        self.entropy_coef = constants["entropy_coefficient"]
        self.final_num_channels, self.final_height, self.final_width = model.image_module.get_final_dimension()

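        # When True, frames whose goal location is None (goal not visible) are skipped
        # during training and evaluation.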
        self.ignore_none = True
        self.inference_procedure = GoalPredictionSingle360ImageSupervisedLearningFromDisk.MODE

        self.vocab = {}
        vocab_path = config["vocab_file"]
        word_index = 0
        with open(vocab_path) as f:
            for line in f.readlines():
                token = line.strip()
                self.vocab[token] = word_index
                word_index += 1

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(self.model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(self.model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.model, num_objects=67, camera_angle=60, image_height=self.final_height,
                image_width=self.final_width, object_height=0)  # -2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(self.model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(self.model, self.final_height, self.final_width)
            self.goal_prediction_loss = None

        self.cross_entropy_loss = None
        self.dist_loss = None

        self.optimizer = optim.Adam(model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.model, self.calc_loss, self.optimizer,
                                  self.config, self.constants, self.tensorboard)

        logging.info("Created Single Image goal predictor with ignore_none %r", self.ignore_none)
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        self.vocab = {}
        vocab_path = config["vocab_file"]
        word_index = 0
        with open(vocab_path) as f:
            for line in f.readlines():
                token = line.strip()
                self.vocab[token] = word_index
                word_index += 1

        # Auxiliary Objectives
        if False:  # self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.model,
                num_objects=67,
                camera_angle=60,
                image_height=self.final_height,
                image_width=self.final_width,
                object_height=0)  # -2.5)
            self.object_detection_loss = None

        self.cross_entropy_loss = None
        self.dist_loss = None

        self.optimizer = optim.Adam(model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)


class ImagePredictionLearning(AbstractLearning):
    """ Perform goal prediction on single images (as opposed to doing it for sequence)
    stored on disk and hence does not need client or server. """
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        self.vocab = {}
        vocab_path = config["vocab_file"]
        word_index = 0
        with open(vocab_path) as f:
            for line in f.readlines():
                token = line.strip()
                self.vocab[token] = word_index
                word_index += 1

        # Auxiliary Objectives
        if False:  # self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.model,
                num_objects=67,
                camera_angle=60,
                image_height=self.final_height,
                image_width=self.final_width,
                object_height=0)  # -2.5)
            self.object_detection_loss = None

        self.cross_entropy_loss = None
        self.dist_loss = None

        self.optimizer = optim.Adam(model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

    def calc_loss(self, batch_replay_items):

        # Only compute the goal prediction loss
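        # The loss is the negative log-likelihood of the gold panorama image index,
        # averaged over the batch.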

        loss = None
        for replay_item in batch_replay_items:
            image, instruction, gold_image_ix = replay_item
            log_prob = self.model.get_log_prob((image, instruction))
            loss_ = -log_prob[0, gold_image_ix]
            if loss is None:
                loss = loss_
            else:
                loss = loss + loss_

        loss = loss / len(batch_replay_items)

        if False:  # self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        return loss

    @staticmethod
    def get_gold_image(data_point):

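        # The 360-degree panorama is composed of 6 camera images, each covering a
        # 60-degree slice. Map the turn angle towards the goal to the index of the
        # image that contains it (index 0 is straight ahead).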
        pos = data_point.get_start_pos()
        metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": pos[2]}
        turn_angle = get_turn_angle_from_metadata_datapoint(
            metadata, data_point)

        assert 180.0 >= turn_angle >= -180.0
        if 30.0 >= turn_angle > -30.0:
            ix = 0
        elif 90.0 >= turn_angle > 30.0:
            ix = 1
        elif 150.0 >= turn_angle > 90.0:
            ix = 2
        elif -30.0 >= turn_angle > -90.0:
            ix = 5
        elif -90.0 >= turn_angle > -150.0:
            ix = 4
        else:
            ix = 3

        return ix

    @staticmethod
    def parse(folder_name, dataset, model):

        start = time.time()

        # Read images
        image_dataset = []
        num_examples = len(os.listdir(folder_name))

        for i in range(0, num_examples):
            example_folder_name = folder_name + "/example_" + str(i)
            images = []
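            # imread returns images as (height, width, channel); the two swapaxes calls
            # convert them to (channel, height, width) for the model.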
            for ix in range(0, 6):  # panorama consists of 6 images stitched together
                img = scipy.misc.imread(example_folder_name +
                                        "/image_" + str(ix) + ".png").swapaxes(
                                            1, 2).swapaxes(0, 1)
                images.append(img)
            image_dataset.append(images)

        # Read the goal state. The data for the single image can be
        # directly computed and does not need to be saved.
        image_index_dataset = []
        for data_point in dataset:
            ix = ImagePredictionLearning.get_gold_image(data_point)
            image_index_dataset.append(ix)

        assert len(image_dataset) == len(dataset) and len(
            image_index_dataset) == len(dataset)

        end = time.time()
        logging.info("Parsed dataset of size %r in time % seconds",
                     len(image_dataset), (end - start))

        return image_dataset, image_index_dataset

    def test(self, tune_dataset, tune_image, tune_goal_location, tensorboard):

        total_validation_loss = 0
        total_validation_exact_accuracy = 0
        total_epsilon_accuracy = 0

        for data_point_ix, data_point in enumerate(tune_dataset):
            tune_image_example = tune_image[data_point_ix]
            tune_image_index = tune_goal_location[data_point_ix]

            log_prob = self.model.get_log_prob(
                (tune_image_example, data_point.instruction))

            loss = -log_prob[0, tune_image_index]
            inferred_ix = int(torch.max(log_prob, 1)[1].data.cpu().numpy()[0])

            if tune_image_index == inferred_ix:
                total_validation_exact_accuracy += 1
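            # Epsilon accuracy: the predicted panorama index is within one 60-degree
            # slice (circular distance <= 1) of the gold index.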
            if min((tune_image_index - inferred_ix) % 6,
                   (inferred_ix - tune_image_index) % 6) <= 1.0:
                total_epsilon_accuracy += 1

            total_validation_loss += loss

        num_items = len(tune_dataset)
        mean_total_validation_loss = total_validation_loss / float(
            max(num_items, 1))
        mean_total_validation_accuracy = (total_validation_exact_accuracy *
                                          100.0) / float(max(num_items, 1))
        mean_total_epsilon_accuracy = (total_epsilon_accuracy * 100.0) / float(
            max(num_items, 1))

        logging.info(
            "Mean Test result: Num items %r, Loss %r, Acc is %r, Epsilon Accuracy is %r"
            % (num_items, mean_total_validation_loss,
               mean_total_validation_accuracy, mean_total_epsilon_accuracy))

    def do_train(self, train_dataset, train_images, train_image_indices,
                 tune_dataset, tune_images, tune_goal_location,
                 experiment_name):
        """ Perform training """

        dataset_size = len(train_dataset)
        tensorboard = self.tensorboard

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)

            # Test on tuning data
            self.test(tune_dataset,
                      tune_images,
                      tune_goal_location,
                      tensorboard=tensorboard)
            batch_replay_items = []

            for data_point_ix, data_point in enumerate(train_dataset):

                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix,
                                 dataset_size)

                # Store it in the replay memory list
                replay_item = (train_images[data_point_ix],
                               data_point.instruction,
                               train_image_indices[data_point_ix])
                batch_replay_items.append(replay_item)

                # Perform update
                if len(batch_replay_items) > 0:
                    loss_val = self.do_update(batch_replay_items)
                    batch_replay_items = []
                    if tensorboard is not None:
                        tensorboard.log_scalar("Loss", loss_val)
                        if self.object_detection_loss is not None:
                            object_detection_loss = float(
                                self.object_detection_loss.data[0])
                            tensorboard.log_scalar("object_detection_loss",
                                                   object_detection_loss)

            # Save the model
            self.model.save_model(experiment_name +
                                  "/goal_prediction_single_supervised_epoch_" +
                                  str(epoch))


class GoalPredictionSupervisedLearningFromDisk(AbstractLearning):
    """ Perform goal prediction on oracle trajectories using images stored on disk
    and hence does not need client or server. """
    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]
        self.final_height, self.final_width = 32, 32

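        # ignore_none: skip frames where the goal is not visible in the image.
        # only_first: use only the first panorama of each trajectory.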
        self.ignore_none = True
        self.only_first = True

        self.vocab = {}
        vocab_path = config["vocab_file"]
        word_index = 0
        with open(vocab_path) as f:
            for line in f.readlines():
                token = line.strip()
                self.vocab[token] = word_index
                word_index += 1

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.model,
                num_objects=67,
                camera_angle=60,
                image_height=self.final_height,
                image_width=self.final_width,
                object_height=0)  # -2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.model, self.final_height, self.final_width)
            self.goal_prediction_loss = None

        self.cross_entropy_loss = None
        self.dist_loss = None

        self.optimizer = optim.Adam(model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

        logging.info(
            "Created Goal predictor with ignore_none %r and only_first %r",
            self.ignore_none, self.only_first)

    def calc_loss(self, batch_replay_items):

        # Compute the goal prediction loss (plus the object detection auxiliary loss, if enabled)
        self.goal_prediction_loss, self.goal_prob, meta = self.goal_prediction_calculator.calc_loss(
            batch_replay_items)
        loss = self.goal_prediction_loss

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        self.cross_entropy_loss = meta["cross_entropy"]
        self.dist_loss = meta["dist_loss"]

        return loss

    @staticmethod
    def parse(folder_name, dataset):

        start = time.time()

        image_dataset = []
        num_examples = len(os.listdir(folder_name))
        for i in range(0, num_examples):
            example_folder_name = folder_name + "/example_" + str(i)
            image_names = [
                file for file in os.listdir(example_folder_name)
                if file.endswith('.png')
            ]
            num_actions = len(image_names)
            images = []
            for j in range(0, num_actions):
                img = scipy.misc.imread(example_folder_name +
                                        "/image_" + str(j) + ".png").swapaxes(
                                            1, 2).swapaxes(0, 1)
                images.append(img)
            image_dataset.append(images)

        # goal_dataset = []
        # num_examples = len(os.listdir(folder_name))
        # for i in range(0, num_examples):
        #     example_folder_name = folder_name + "/example_" + str(i)
        #     lines = open(example_folder_name + "/goal.txt").readlines()
        #     goals = []
        #     for line in lines:
        #         words = line.strip().split()
        #         assert len(words) == 4
        #         if words[0] == "None" or words[1] == "None":
        #             row, col, row_real, col_real = None, None, None, None
        #         else:
        #             row, col, row_real, col_real = int(words[0]), int(words[1]), float(words[2]), float(words[3])
        #             assert 0 <= row < 8 and 0 <= col < 8
        #         goals.append((row, col, row_real, col_real))
        #     goal_dataset.append(goals)
        #
        # assert len(image_dataset) == len(goal_dataset)
        assert len(image_dataset) == len(dataset)

        ####################################
        #  Hack for synthetic data
        goal_dataset = []
        for i in range(0, num_examples):
            data_point = dataset[i]
            pos = data_point.get_start_pos()
            metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": pos[2]}

            goal_location = [
                GoalPrediction.get_goal_location(metadata, data_point, 32, 32)
            ]
            _, _, row, col = goal_location[0]

            start_pos = current_pos_from_metadata(metadata)
            start_pose = current_pose_from_metadata(metadata)

            if row is not None and col is not None:
                goal_pos = data_point.get_destination_list()[-1]
                height_drone = 2.5
                x_gen, y_gen = get_inverse_object_position(
                    row, col, height_drone, 30, 32, 32,
                    (start_pos[0], start_pos[1], start_pose))

            goal_dataset.append(goal_location)
            data_point.trajectory = [0]
            # if len(image_dataset[i]) >= 2:
            #     data_point.trajectory = [0]  # dummy action added
        #####################################

        end = time.time()
        logging.info("Parsed dataset of size %r in time % seconds",
                     len(image_dataset), (end - start))
        return image_dataset, goal_dataset

    def convert_to_id(self, instruction):
        tk_seq = instruction.split()
        token_ids = []
        for tk in tk_seq:
            if tk in self.vocab:
                token_ids.append(self.vocab[tk])
            else:
                print("Out of vocabulary word. Ignoring ", tk)
        return token_ids

    def is_close_enough(self, inferred_ix, row, col):
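        # A prediction is 'close enough' if the predicted pixel lies within 10% of the
        # larger image dimension (Euclidean distance) from the gold pixel.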
        predicted_row = int(inferred_ix / float(self.final_width))
        predicted_col = inferred_ix % self.final_width

        row_diff = row - predicted_row
        col_diff = col - predicted_col

        dist = math.sqrt(row_diff * row_diff + col_diff * col_diff)

        max_dim = float(max(self.final_height, self.final_width))
        if dist < 0.1 * max_dim:
            return True
        else:
            return False

    def interactive_shell(self, train_dataset, train_images):

        traj_len = len(train_dataset)
        keep = False
        image_id = 1
        while True:

            # Sample a random dataset
            if not keep:
                ix = random.randint(0, traj_len - 1)
            data_point = train_dataset[ix]
            image = train_images[ix][0]

            # Show the image in pyplot
            plt.imshow(image.swapaxes(0, 1).swapaxes(1, 2))
            plt.ion()
            plt.show()

            # Get the instruction
            print("Enter the instruction below (q or quit to quit)\n")
            print("Sample instruction is ",
                  instruction_to_string(data_point.instruction, self.config))
            while True:
                instruction = input()
                if instruction == "q" or instruction == "quit":
                    break
                elif len(instruction) == 0:
                    print("Enter a non-empty instruction (q or quit to quit)")
                else:
                    break

            instruction_id = self.convert_to_id(instruction)
            state = AgentObservedState(instruction=instruction_id,
                                       config=self.config,
                                       constants=self.constants,
                                       start_image=image,
                                       previous_action=None,
                                       pose=None,
                                       position_orientation=None,
                                       data_point=data_point)

            # Show the attention mask
            _, _, _, volatile = self.model.get_attention_prob(state, model_state=None)

            attention_prob = volatile["attention_probs"][:-1].view(
                self.final_height, self.final_width)
            attention_prob = attention_prob.cpu().data.numpy()
            resized_kernel = scipy.misc.imresize(
                attention_prob,
                (self.config["image_height"], self.config["image_width"]))
            plt.clf()
            plt.title(instruction)
            plt.imshow(image.swapaxes(0, 1).swapaxes(1, 2))
            plt.imshow(resized_kernel, cmap="jet", alpha=0.5)

            print(
                "Enter s to save, k to keep working on this environment, sk to do both. Other key to simply continue"
            )
            key_ = input()
            if key_ == "s":
                plt.savefig("interactive_image_" + str(image_id) + ".png")
                image_id += 1

            if key_ == "k":
                keep = True
            else:
                keep = False

            if key_ == "sk":
                plt.savefig("image_" + str(image_id) + ".png")
                image_id += 1
                keep = True

            plt.clf()

    def test(self, tune_dataset, tune_image, tune_goal_location, tensorboard):

        total_validation_loss = 0
        total_validation_prob = 0
        total_validation_exact_accuracy = 0
        total_goal_distance = 0
        num_items = 0

        # Next metric measures when the goal is visible and the prediction is within a 10% radius
        total_epsilon_accuracy = 0
        num_visible_items = 0

        for data_point_ix, data_point in enumerate(tune_dataset):
            tune_image_example = tune_image[data_point_ix]
            goal_location = tune_goal_location[data_point_ix]
            image = tune_image_example[0]

            model_state = None
            state = AgentObservedState(
                instruction=data_point.instruction,
                config=self.config,
                constants=self.constants,
                start_image=image,
                previous_action=None,
                pose=None,
                position_orientation=data_point.get_start_pos(),
                data_point=data_point)
            trajectory = data_point.get_trajectory()
            if self.only_first:
                trajectory = trajectory[0:1]
            traj_len = len(trajectory)
            num_items_ = 0

            sum_loss = 0
            sum_prob = 0
            sum_acc = 0
            sum_dist = 0

            for action_ix, action in enumerate(trajectory):
                state.goal = goal_location[action_ix]
                volatile = self.model.get_attention_prob(state, model_state)
                goal = goal_location[action_ix]
                row, col, _, _ = goal

                if not self.ignore_none or row is not None:
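                    # The gold index is the flattened pixel position (row * width + col);
                    # the extra index height * width denotes a goal that is not visible.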
                    if row is None:
                        gold_ix = self.final_height * self.final_width
                    else:
                        gold_ix = row * self.final_width + col
                    loss, prob, meta = GoalPrediction.get_loss_and_prob(
                        volatile, goal, self.final_height, self.final_width)
                    num_items_ += 1
                    sum_loss = sum_loss + float(loss.data.cpu().numpy()[0])
                    sum_prob = sum_prob + float(prob.data.cpu().numpy()[0])

                    inferred_ix = int(
                        torch.max(volatile["attention_logits"],
                                  0)[1].data.cpu().numpy()[0])
                    if gold_ix == inferred_ix:
                        sum_acc = sum_acc + 1.0
                    if row is not None:
                        sum_dist = sum_dist + abs(row - int(round(inferred_ix / self.final_width))) \
                                   + abs(col - int(inferred_ix % self.final_width))

                    if row is not None:
                        num_visible_items += 1
                        if self.is_close_enough(inferred_ix, row, col):
                            total_epsilon_accuracy += 1

                if not self.only_first:
                    image = tune_image_example[action_ix + 1]
                    state = state.update(image,
                                         action,
                                         pose=None,
                                         position_orientation=None,
                                         data_point=data_point)

            if not self.only_first:
                state.goal = goal_location[traj_len]
                volatile = self.model.get_attention_prob(state, model_state)
                goal = goal_location[traj_len]
                row, col, _, _ = goal

                if not self.ignore_none or row is not None:
                    if row is None:
                        gold_ix = self.final_height * self.final_width
                    else:
                        gold_ix = row * self.final_width + col

                    loss, prob, _ = GoalPrediction.get_loss_and_prob(
                        volatile, goal, self.final_height, self.final_width)
                    num_items_ += 1
                    sum_loss = sum_loss + float(loss.data.cpu().numpy()[0])
                    sum_prob = sum_prob + float(prob.data.cpu().numpy()[0])
                    inferred_ix = int(
                        torch.max(volatile["attention_logits"],
                                  0)[1].data.cpu().numpy()[0])
                    if gold_ix == inferred_ix:
                        sum_acc = sum_acc + 1.0
                    if row is not None:
                        sum_dist = sum_dist + abs(row - int(round(inferred_ix/self.final_width))) \
                                   + abs(col - int(inferred_ix % self.final_width))

                    if row is not None:
                        num_visible_items += 1
                        if self.is_close_enough(inferred_ix, row, col):
                            total_epsilon_accuracy += 1

            total_validation_loss += sum_loss
            total_validation_prob += sum_prob
            total_goal_distance += sum_dist
            total_validation_exact_accuracy += sum_acc
            num_items += num_items_

        mean_total_goal_distance = total_goal_distance / float(
            max(num_items, 1))
        mean_total_validation_loss = total_validation_loss / float(
            max(num_items, 1))
        mean_total_validation_prob = total_validation_prob / float(
            max(num_items, 1))
        mean_total_validation_accuracy = (total_validation_exact_accuracy *
                                          100.0) / float(max(num_items, 1))
        mean_total_epsilon_accuracy = (total_epsilon_accuracy * 100.0) / float(
            max(num_visible_items, 1))

        logging.info(
            "Mean Test result: L1 Distance is %r, Loss %r, Prob %r, Acc is %r, Epsilon Accuracy is %r"
            % (mean_total_goal_distance, mean_total_validation_loss,
               mean_total_validation_prob, mean_total_validation_accuracy,
               mean_total_epsilon_accuracy))
        logging.info(
            "Num visible items %r, Num Exact Match items is %r, Num epsilon match %r, Num Items is %r "
            % (num_visible_items, total_validation_exact_accuracy,
               total_epsilon_accuracy, num_items))

    def do_train(self, train_dataset, train_images, train_goal_location,
                 tune_dataset, tune_images, tune_goal_location,
                 experiment_name):
        """ Perform training """

        dataset_size = len(train_dataset)
        tensorboard = self.tensorboard

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)

            # Test on tuning data
            self.test(tune_dataset,
                      tune_images,
                      tune_goal_location,
                      tensorboard=tensorboard)

            for data_point_ix, data_point in enumerate(train_dataset):

                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix,
                                 dataset_size)

                train_images_example = train_images[data_point_ix]
                goal_location = train_goal_location[data_point_ix]
                image = train_images_example[0]

                model_state = None
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=self.config,
                    constants=self.constants,
                    start_image=image,
                    previous_action=None,
                    pose=None,
                    position_orientation=data_point.get_start_pos(),
                    data_point=data_point)

                trajectory = data_point.get_trajectory()
                traj_len = len(trajectory)
                if self.only_first:
                    trajectory = trajectory[0:1]
                batch_replay_items = []

                for action_ix, action in enumerate(trajectory):

                    # Sample action using the policy
                    # Generate probabilities over actions
                    volatile = self.model.get_attention_prob(
                        state, model_state)
                    goal = goal_location[action_ix]

                    # Store it in the replay memory list
                    if not self.ignore_none or goal[0] is not None:
                        replay_item = ReplayMemoryItem(state,
                                                       action,
                                                       0,
                                                       volatile=volatile,
                                                       goal=goal)
                        batch_replay_items.append(replay_item)

                    if not self.only_first:
                        # Send the action and get feedback
                        image = train_images_example[action_ix + 1]

                        # Update the agent state
                        state = state.update(image,
                                             action,
                                             pose=None,
                                             position_orientation=None,
                                             data_point=data_point)

                # Store it in the replay memory list
                if not self.only_first:
                    goal = goal_location[traj_len]
                    if not self.ignore_none or goal[0] is not None:
                        volatile = self.model.get_attention_prob(
                            state, model_state)
                        replay_item = ReplayMemoryItem(
                            state,
                            self.action_space.get_stop_action_index(),
                            0,
                            volatile=volatile,
                            goal=goal)
                        batch_replay_items.append(replay_item)

                # Perform update
                if len(batch_replay_items) > 0:
                    loss_val = self.do_update(batch_replay_items)
                    if tensorboard is not None:
                        tensorboard.log_scalar("Loss", loss_val)
                        if self.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                self.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)
                        if self.goal_prob is not None:
                            goal_prob = float(self.goal_prob.data[0])
                            tensorboard.log_scalar("goal_prob", goal_prob)
                        if self.object_detection_loss is not None:
                            object_detection_loss = float(
                                self.object_detection_loss.data[0])
                            tensorboard.log_scalar("object_detection_loss",
                                                   object_detection_loss)
                        if self.cross_entropy_loss is not None:
                            cross_entropy_loss = float(
                                self.cross_entropy_loss.data[0])
                            tensorboard.log_scalar("Cross_entropy_loss",
                                                   cross_entropy_loss)
                        if self.dist_loss is not None:
                            dist_loss = float(self.dist_loss.data[0])
                            tensorboard.log_scalar("Dist_loss", dist_loss)

            # Save the model
            self.model.save_model(experiment_name +
                                  "/goal_prediction_supervised_epoch_" +
                                  str(epoch))


class AsynchronousAdvantageActorGAECritic(AbstractLearning):
    """ Perform Asynchronous Advantage Actor Critic with Generalized Advantage Estimate """
    def __init__(self, shared_model, local_model, action_space, meta_data_util,
                 config, constants, tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.gamma = constants["gamma"]
        self.tau = 1.0
        self.shared_model = shared_model
        self.local_model = local_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.ratio = None
        self.value_loss = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        self.image_channels, self.image_height, self.image_width = \
            shared_model.image_module.get_final_dimension()

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.local_model,
                num_objects=67,
                camera_angle=60,
                image_height=self.image_height,
                image_width=self.image_width,
                object_height=0)  # -2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.local_model, self.image_height, self.image_width)
            self.goal_prediction_loss = None

        self.optimizer = optim.Adam(shared_model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_model, self.local_model,
                                  self.calc_loss, self.optimizer, self.config,
                                  self.constants, self.tensorboard)

    def calc_loss(self, batch_replay_items):
        """ Assumes that the batch replay items contains items ordered temporarily """

        agent_observation_state_ls = []
        action_batch = []
        log_probabilities = []
        factor_entropy = []
        v_values = []
        for replay_item in batch_replay_items:
            agent_observation_state_ls.append(
                replay_item.get_agent_observed_state())
            action_batch.append(replay_item.get_action())
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())
            v_values.append(replay_item.get_volatile_features()["state_value"])

        # Compute the generalized advantage.
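        # GAE recursion (processed in reverse temporal order): A_t = delta_t + gamma * tau * A_{t+1},
        # where delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); for the final step delta_t = r_t - V(s_t).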
        generalized_advantages = []
        total_reward = []
        last_v_value = None
        sum_reward = 0
        generalized_advantage = cuda_var(torch.zeros(1))
        for replay_item in reversed(batch_replay_items):

            v_value = replay_item.get_volatile_features()["state_value"]

            if last_v_value is None:
                reward = replay_item.get_reward()
                q_val = reward
                advantage = q_val - v_value
                generalized_advantage = advantage
            else:
                reward = replay_item.get_reward()
                q_val = reward + self.gamma * last_v_value
                advantage = q_val - v_value
                generalized_advantage = self.tau * self.gamma * generalized_advantage + advantage

            sum_reward += reward
            last_v_value = v_value
            generalized_advantages.append(generalized_advantage)
            total_reward.append(sum_reward)

        # Reverse the advantages and total reward to temporal order
        generalized_advantages.reverse()
        total_reward.reverse()

        log_probabilities = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))
        generalized_advantages = torch.cat(generalized_advantages).view(-1)

        total_reward = cuda_var(
            torch.from_numpy(np.array(total_reward)).float()).view(-1)
        v_values = torch.cat(v_values).view(-1)

        model_log_prob_batch = log_probabilities
        chosen_log_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))
        advantage_log_prob = generalized_advantages * chosen_log_probs.view(-1)

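        # Reference action distribution; the cross-entropy against the mini-batch action
        # distribution is computed as a diagnostic (see the commented-out alternative loss below).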
        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(model_log_prob_batch)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        self.value_loss = torch.sum((v_values - total_reward)**2)

        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        # self.entropy = -torch.mean(torch.sum(model_log_prob_batch * model_prob_batch, 1))
        self.entropy = -torch.sum(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(advantage_log_prob)  # / num_states
        # Essentially we want the objective to increase and cross entropy to decrease
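        # Linearly anneal the entropy bonus by 0.01 per epoch (never below zero).
        # The loss combines the negated policy-gradient objective, the entropy bonus,
        # and a 0.25-weighted value loss.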
        entropy_coef = max(0, self.entropy_coef - self.epoch * 0.01)
        loss = -objective - entropy_coef * self.entropy + 0.25 * self.value_loss
        self.ratio = torch.abs(objective) / (entropy_coef * self.entropy)  # we want the ratio to be high

        # loss = -objective + self.entropy_coef * self.cross_entropy

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            self.symbolic_language_prediction_loss = self.constants["symbolic_language_prediction_coeff"] * \
                                                     self.symbolic_language_prediction_loss
            loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            self.goal_prediction_loss, _, _ = self.goal_prediction_calculator.calc_loss(
                batch_replay_items)
            if self.goal_prediction_loss is not None:
                self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * \
                                            self.goal_prediction_loss
                loss = loss + self.goal_prediction_loss  # * len(batch_replay_items)  # scale the loss
        else:
            self.goal_prediction_loss = None

        return loss

    @staticmethod
    def do_train(shared_model,
                 config,
                 action_space,
                 meta_data_util,
                 constants,
                 train_dataset,
                 tune_dataset,
                 experiment,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 model_type,
                 use_pushover=False):
        try:
            AsynchronousAdvantageActorGAECritic.do_train_(
                shared_model, config, action_space, meta_data_util, constants,
                train_dataset, tune_dataset, experiment, experiment_name, rank,
                server, logger, model_type, use_pushover)
        except Exception:
            exc_info = sys.exc_info()
            traceback.print_exception(*exc_info)

    @staticmethod
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)
        # local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = AsynchronousAdvantageActorGAECritic(shared_model,
                                                      local_model,
                                                      action_space,
                                                      meta_data_util, config,
                                                      constants, tensorboard)

        # Launch unity
        launch_k_unity_builds([config["port"]],
                              "./simulators/NavDroneLinuxBuild.x86_64")

        for epoch in range(1, max_epochs + 1):

            learner.epoch = epoch
            task_completion_accuracy = 0
            mean_stop_dist_error = 0
            stop_dist_errors = []
            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_model.load_from_state_dict(shared_model.get_state_dict())

                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                num_actions = 0
                max_num_actions = constants["horizon"] + constants[
                    "max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                pose = int(metadata["y_angle"] / 15.0)
                position_orientation = (metadata["x_pos"], metadata["z_pos"],
                                        metadata["y_angle"])
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    pose=pose,
                    position_orientation=position_orientation,
                    data_point=data_point)
                state.goal = GoalPrediction.get_goal_location(
                    metadata, data_point, learner.image_height,
                    learner.image_width)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                forced_stop = True

                while num_actions < max_num_actions:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, volatile = \
                        local_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    # Generate goal
                    if config["do_goal_prediction"]:
                        goal = learner.goal_prediction_calculator.get_goal_location(
                            metadata, data_point, learner.image_height,
                            learner.image_width)
                    else:
                        goal = None

                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   volatile=volatile,
                                                   goal=goal)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    pose = int(metadata["y_angle"] / 15.0)
                    position_orientation = (metadata["x_pos"],
                                            metadata["z_pos"],
                                            metadata["y_angle"])
                    state = state.update(
                        image,
                        action,
                        pose=pose,
                        position_orientation=position_orientation,
                        data_point=data_point)
                    state.goal = GoalPrediction.get_goal_location(
                        metadata, data_point, learner.image_height,
                        learner.image_width)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback()
                total_reward += reward

                if metadata["stop_dist_error"] < 5.0:
                    task_completion_accuracy += 1
                mean_stop_dist_error += metadata["stop_dist_error"]
                stop_dist_errors.append(metadata["stop_dist_error"])

                if tensorboard is not None:
                    tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

                # Store it in the replay memory list
                if not forced_stop:
                    replay_item = ReplayMemoryItem(
                        state,
                        action_space.get_stop_action_index(),
                        reward,
                        log_prob=log_probabilities,
                        volatile=volatile,
                        goal=goal)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                if len(batch_replay_items) > 0:  # 32:
                    loss_val = learner.do_update(batch_replay_items)
                    # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                    # del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        entropy = float(
                            learner.entropy.data[0]) / float(num_actions + 1)
                        v_value_loss_per_step = float(
                            learner.value_loss.data[0]) / float(num_actions +
                                                                1)
                        tensorboard.log_scalar("entropy", entropy)
                        tensorboard.log_scalar("total_reward", total_reward)
                        tensorboard.log_scalar("v_value_loss_per_step",
                                               v_value_loss_per_step)
                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)

            # Save the model
            local_model.save_model(experiment + "/contextual_bandit_" +
                                   str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)
            mean_stop_dist_error = mean_stop_dist_error / float(
                len(train_dataset))
            task_completion_accuracy = (task_completion_accuracy *
                                        100.0) / float(len(train_dataset))
            logger.log("Training: Mean stop distance error %r" %
                       mean_stop_dist_error)
            logger.log("Training: Task completion accuracy %r " %
                       task_completion_accuracy)
            bins = range(0, 80, 3)  # range of distance
            histogram, _ = np.histogram(stop_dist_errors, bins)
            logger.log("Histogram of train errors %r " % histogram)

            if tune_dataset_size > 0:
                # Test on tuning data
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)
class BlockGoalPredictionSupervisedLearningFromDisk(AbstractLearning):
    """ Perform goal prediction on single images (as opposed to doing it for sequence)
    stored on disk and hence does not need client or server. """

    CLOCKWISE, BACKSTICH = range(2)
    image_stich = CLOCKWISE

    MODE, MEAN, REALMEAN, MEANAROUNDMODE = range(4)
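    # Candidate procedures for reading a goal pixel off the predicted attention
    # distribution. get_inferred_value below only implements MODE (argmax over the
    # attention logits), and __init__ sets self.inference_procedure to MODE.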

    def __init__(self, model, action_space, meta_data_util, config, constants,
                 tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.epoch = 0
        self.global_id = 0
        self.entropy_coef = constants["entropy_coefficient"]
        self.final_num_channels, self.final_height, self.final_width = model.image_module.get_final_dimension(
        )

        self.ignore_none = True
        self.inference_procedure = BlockGoalPredictionSupervisedLearningFromDisk.MODE

        self.vocab = {}
        vocab_path = config["vocab_file"]
        word_index = 0
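        # The vocabulary file is assumed to contain one token per line; tokens are
        # indexed in file order and looked up again in convert_to_id below.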
        with open(vocab_path) as f:
            for line in f.readlines():
                token = line.strip()
                self.vocab[token] = word_index
                word_index += 1

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.model,
                num_objects=67,
                camera_angle=60,
                image_height=self.final_height,
                image_width=self.final_width,
                object_height=0)  # -2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.model, self.final_height, self.final_width)
            self.goal_prediction_loss = None

        self.cross_entropy_loss = None
        self.dist_loss = None

        self.optimizer = optim.Adam(model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

        logging.info("Created Single Image goal predictor with ignore_none %r",
                     self.ignore_none)

    def calc_loss(self, batch_replay_items):

        # Only compute the goal prediction loss
        # loss = None
        # for replay_item in batch_replay_items:
        #     self.goal_prediction_loss, self.goal_prob, meta = self.goal_prediction_calculator.calc_loss(
        #         [replay_item])
        #     if loss is None:
        #         loss = self.goal_prediction_loss
        #     else:
        #         loss += self.goal_prediction_loss
        #
        # loss = loss / float(len(batch_replay_items))
        self.goal_prediction_loss, self.goal_prob, meta = self.goal_prediction_calculator.calc_loss(
            batch_replay_items)
        loss = self.goal_prediction_loss

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        self.cross_entropy_loss = meta["cross_entropy"]
        self.dist_loss = meta["dist_loss"]

        return loss

    @staticmethod
    def parse(folder_name, dataset, vocab, debug=False):

        start = time.time()

        with open(folder_name + "/dataset_goal.json") as f:

            data = json.load(f)

            for datapoint in data:

                # Read dataset information
                i = int(datapoint["id"])
                screen_width = datapoint["screenWidth"]
                screen_height = datapoint["screenHeight"]
                instr_str = datapoint["instruction"]
                gold_block_id = int(datapoint["goldBlockId"])

                start_loc_json = datapoint["startLocation"]
                start_loc_str = [
                    start_loc_json["x"], start_loc_json["y"],
                    start_loc_json["z"]
                ]
                start_loc = [float(w) for w in start_loc_str]

                goal_loc_json = datapoint["goalLocation"]
                goal_loc_str = [
                    goal_loc_json["x"], goal_loc_json["y"], goal_loc_json["z"]
                ]
                goal_loc = [float(w) for w in goal_loc_str]

                goal_pixel_json = datapoint["goalPixel"]
                goal_pixel_str = [
                    goal_pixel_json["x"], goal_pixel_json["y"],
                    goal_pixel_json["z"]
                ]
                goal_pixel = [float(w) for w in goal_pixel_str]

                # Read the image
                image = np.load(folder_name + "/example_%s/image.npy" % i)
                image = image.swapaxes(0, 1).swapaxes(1, 2)
                image = np.rot90(image, k=2)
                image = np.fliplr(image)
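                # The stored array is assumed to be channel-first; the two swaps
                # convert it to height x width x channels, and the rotation/flip are
                # assumed to align it with the screen coordinates used for the goal
                # pixel below.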

                # Read the goal information
                lines = open(folder_name +
                             "/example_%s/instruction.txt" % i).readlines()
                assert len(lines) == 2
                instruction = lines[0]

                pixel = ((screen_height - goal_pixel[1]) /
                         float(screen_height),
                         (goal_pixel[0]) / float(screen_width))
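                # The goal pixel arrives in screen coordinates with y measured from
                # the bottom of the screen (assumption), so the row is flipped via
                # screen_height - y; both coordinates are normalised to [0, 1].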

                if debug:
                    # save the image for debugging
                    pixel_row_real = int(128 * pixel[0])
                    pixel_col_real = int(128 * pixel[1])

                    if pixel_row_real < 0 or pixel_row_real >= 128 or pixel_col_real < 0 or pixel_col_real >= 128:
                        raise AssertionError(
                            "Goal pixel %r lies outside the 128x128 image" %
                            ((pixel_row_real, pixel_col_real),))

                    goal = np.zeros((128, 128))
                    for i1 in range(-3, 3):
                        for j1 in range(-3, 3):
                            if pixel_row_real + i1 < 0 or pixel_row_real + i1 >= 128:
                                continue
                            if pixel_col_real + j1 < 0 or pixel_col_real + j1 >= 128:
                                continue
                            goal[pixel_row_real + i1][pixel_col_real +
                                                      j1] = 1.0

                    f, axarr = plt.subplots(1, 2)
                    if instruction is not None:
                        f.suptitle(instruction)
                    axarr[0].imshow(image)
                    axarr[0].imshow(goal, cmap='jet', alpha=0.5)
                    plt.savefig("./goal_block/image_" + str(i) + ".png")
                    plt.clf()

                instruction_indices = TmpAsynchronousContextualBandit.convert_text_to_indices(
                    instruction, vocab)

                python_datapoint = dataset[i - 1]
                python_datapoint.set_instruction(instruction_indices,
                                                 instruction)
                python_datapoint.set_start_image(
                    image.swapaxes(1, 2).swapaxes(0, 1))
                python_datapoint.set_block_id(gold_block_id)
                python_datapoint.set_start_location(start_loc)
                python_datapoint.set_goal_location(goal_loc)

                scaled_pixel_row, scaled_pixel_col = int(pixel[0] * 32), int(
                    pixel[1] * 32)
                if scaled_pixel_row < 0:
                    scaled_pixel_row = 0
                elif scaled_pixel_row >= 32:
                    scaled_pixel_row = 31

                if scaled_pixel_col < 0:
                    scaled_pixel_col = 0
                elif scaled_pixel_col >= 32:
                    scaled_pixel_col = 31

                python_datapoint.set_goal_pixel(
                    (scaled_pixel_row, scaled_pixel_col))

                predicted_x = 1.25 - 2.5 * (scaled_pixel_col / 32.0
                                            )  # ranges between -1.25 and 1.25
                predicted_z = 2.5 * (scaled_pixel_row / 32.0
                                     ) - 1.25  # ranges between -1.25 and 1.25
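                # Sanity check of the 32x32 grid -> world mapping, assuming a
                # 2.5 x 2.5 world centred at the origin: column 0 gives x = 1.25,
                # column 31 gives x ~= -1.17, and row 0 gives z = -1.25.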

                print("Pixel is %r and Goal is %r and Prediction is %r " %
                      ((scaled_pixel_row, scaled_pixel_col),
                       (goal_loc[0], goal_loc[2]), (predicted_x, predicted_z)))

        end = time.time()
        logging.info("Parsed dataset of size %r in time % seconds",
                     len(dataset), (end - start))

    def convert_to_id(self, instruction):
        tk_seq = instruction.split()
        token_ids = []
        for tk in tk_seq:
            if tk in self.vocab:
                token_ids.append(self.vocab[tk])
            else:
                print("Out of vocabulary word. Ignoring ", tk)
        return token_ids

    def is_close_enough(self, inferred_ix, row, col):
        predicted_row = int(inferred_ix / float(self.final_width))
        predicted_col = inferred_ix % self.final_width

        row_diff = row - predicted_row
        col_diff = col - predicted_col

        dist = math.sqrt(row_diff * row_diff + col_diff * col_diff)

        max_dim = float(max(self.final_height, self.final_width))
        return dist < 0.1 * max_dim

    def compute_distance_in_real_world(self, inferred_ix, data_point):

        # find predicted pixel
        row = int(inferred_ix / 32) + 0.5
        col = inferred_ix % 32 + 0.5

        # convert to real world location
        predicted_x = 1.25 - 2.5 * (col / 32.0)  # ranges between -1.25 and 1.25
        predicted_z = 2.5 * (row / 32.0) - 1.25  # ranges between -1.25 and 1.25

        # gold location
        gold_x, _, gold_z = data_point.goal_location

        # print("Predicted Goal %r and Goal is %r " % ((row, col), data_point.goal_pixel))
        # print("Predicted Goal Location %r and Goal Location %r" % ((predicted_x, predicted_z), (gold_x, gold_z)))

        l2_distance = math.sqrt((predicted_x - gold_x) *
                                (predicted_x - gold_x) +
                                (predicted_z - gold_z) *
                                (predicted_z - gold_z))

        block_size = 0.1524
        bisk_distance = l2_distance / block_size
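        # 0.1524 m (6 inches) is assumed to be the side length of one block, so the
        # distance returned below is measured in block lengths (hence
        # "bisk_distance", presumably after the Bisk et al. blocks dataset).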
        return bisk_distance

    def get_inferred_value(self, volatile):

        # Mode setting
        inferred_ix = int(
            torch.max(volatile["attention_logits"],
                      0)[1].data.cpu().numpy()[0])
        return inferred_ix, None

    def save_attention_prob(self,
                            image,
                            attention_prob,
                            instruction,
                            goal_prob=None):

        image_flipped = image.swapaxes(0, 1).swapaxes(1, 2)
        image_flipped = scipy.misc.imresize(image_flipped, (128, 128))
        attention_prob = attention_prob.cpu().data.numpy()
        resized_kernel = scipy.misc.imresize(attention_prob, (128, 128))
        if goal_prob is not None:
            goal_location = goal_prob.cpu().data.numpy()
            if np.sum(goal_location) > 0.01:
                for i in range(0, 32):
                    for j in range(0, 32):
                        if goal_location[i][j] < 0.01:
                            goal_location[i][j] = 0.0
                goal_location = scipy.misc.imresize(goal_location, (128, 128))
        else:
            goal_location = None

        f, axarr = plt.subplots(1, 2)
        if instruction is not None:
            f.suptitle(instruction)
        axarr[0].set_title("Predicted Attention")
        axarr[0].imshow(image_flipped)
        axarr[0].imshow(resized_kernel, cmap='jet', alpha=0.5)
        axarr[1].set_title("Gold Attention (Goal)")
        axarr[1].imshow(image_flipped)
        if goal_location is not None:
            axarr[1].imshow(goal_location, cmap='jet', alpha=0.5)
        plt.savefig("./attention_prob/image_" + str(self.global_id) + ".png")
        plt.clf()

    def show_image(self, goal, predicted_goal, start_pos, instruction):
        self.global_id += 1

        # image_flipped = image.swapaxes(0, 1).swapaxes(1, 2)
        # image_flipped = scipy.misc.imresize(image_flipped, (128, 128 * 6))
        goal_map = np.zeros((50, 50))
        predicted_goal_map = np.zeros((50, 50))

        x_1, y_1 = goal
        x_2, y_2 = predicted_goal
        x_3, y_3, _ = start_pos

        x_1 = min(x_1, 274.99)
        y_1 = min(y_1, 274.99)
        x_2 = min(x_2, 274.99)
        y_2 = min(y_2, 274.99)
        x_3 = min(x_3, 274.99)
        y_3 = min(y_3, 274.99)

        print(" %r %r %r %r " % (x_1, y_1, x_2, y_2))
        assert 225.0 <= x_1 <= 275.0
        assert 225.0 <= x_2 <= 275.0
        assert 225.0 <= x_3 <= 275.0
        assert 225.0 <= y_1 <= 275.0
        assert 225.0 <= y_2 <= 275.0
        assert 225.0 <= y_3 <= 275.0

        i1, j1 = int((x_1 - 225.0)), int((y_1 - 225.0))
        i2, j2 = int((x_2 - 225.0)), int((y_2 - 225.0))
        i3, j3 = int((x_3 - 225.0)), int((y_3 - 225.0))

        goal_map[i1, j1] = 1.0
        goal_map[i3, j3] = 0.75
        predicted_goal_map[i2, j2] = 1.0
        predicted_goal_map[i3, j3] = 0.75
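        # Each map marks its goal with intensity 1.0 and the shared start position
        # with 0.75, so both points remain distinguishable in the heat map.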

        f, axarr = plt.subplots(1, 2)
        if instruction is not None:
            f.suptitle(instruction)
        axarr[0].set_title("Predicted Goal")
        # axarr[0].imshow(image_flipped)
        axarr[0].imshow(predicted_goal_map, cmap='jet', alpha=0.5)
        axarr[1].set_title("Gold Goal")
        # axarr[1].imshow(image_flipped)
        axarr[1].imshow(goal_map, cmap='jet', alpha=0.5)
        plt.savefig("./attention_prob/image_" + str(self.global_id) +
                    "_maps.png")
        plt.clf()

    def interactive_shell(self, train_dataset, train_images):

        traj_len = len(train_dataset)
        keep = False
        image_id = 1
        while True:

            # Sample a random dataset
            if not keep:
                ix = random.randint(0, traj_len - 1)
            data_point = train_dataset[ix]
            image = train_images[ix][0]

            # Show the image in pyplot
            plt.imshow(image.swapaxes(0, 1).swapaxes(1, 2))
            plt.ion()
            plt.show()

            # Get the instruction
            print("Enter the instruction below (q or quit to quit)\n")
            print("Sample instruction is ",
                  instruction_to_string(data_point.instruction, self.config))
            while True:
                instruction = input()
                if instruction == "q" or instruction == "quit":
                    break
                elif len(instruction) == 0:
                    print("Enter a non-empty instruction (q or quit to quit)")
                else:
                    break

            instruction_id = self.convert_to_id(instruction)
            state = AgentObservedState(instruction=instruction_id,
                                       config=self.config,
                                       constants=self.constants,
                                       start_image=image,
                                       previous_action=None,
                                       pose=None,
                                       position_orientation=None,
                                       data_point=data_point)

            # Show the attention mask
            _, _, _, volatile = self.model.get_attention_prob(state,
                                                              model_state=None)

            attention_prob = volatile["attention_probs"][:-1].view(
                self.final_height, self.final_width)
            attention_prob = attention_prob.cpu().data.numpy()
            resized_kernel = scipy.misc.imresize(
                attention_prob,
                (self.config["image_height"], self.config["image_width"]))
            plt.clf()
            plt.title(instruction)
            plt.imshow(image.swapaxes(0, 1).swapaxes(1, 2))
            plt.imshow(resized_kernel, cmap="jet", alpha=0.5)

            print(
                "Enter s to save, k to keep working on this environment, sk to do both. Other key to simply continue"
            )
            key_ = input()
            if key_ == "s":
                plt.savefig("interactive_image_" + str(image_id) + ".png")
                image_id += 1

            if key_ == "k":
                keep = True
            else:
                keep = False

            if key_ == "sk":
                plt.savefig("image_" + str(image_id) + ".png")
                image_id += 1
                keep = True

            plt.clf()

    def test(self, tune_dataset, tensorboard):

        total_validation_loss = 0
        total_validation_prob = 0
        total_validation_exact_accuracy = 0
        total_goal_distance = 0
        num_items = 0

        # Next metric measures when the goal is visible and the prediction is within a 10% radius
        total_epsilon_accuracy = 0
        num_visible_items = 0

        # Next metric measures real-world distance, computed only when the goal is visible
        total_real_world_distance = 0

        for data_point_ix, data_point in enumerate(tune_dataset):

            model_state = None
            state = AgentObservedState(instruction=data_point.instruction,
                                       config=self.config,
                                       constants=self.constants,
                                       start_image=data_point.start_image,
                                       previous_action=None,
                                       pose=None,
                                       position_orientation=None,
                                       data_point=data_point)

            num_items_ = 0
            sum_loss = 0
            sum_prob = 0
            sum_acc = 0
            sum_dist = 0
            sum_real_world_distance = 0

            row, col = data_point.goal_pixel
            goal = row, col, row, col
            state.goal = goal
            volatile = self.model.get_attention_prob(state, model_state)

            if not self.ignore_none or row is not None:
                gold_ix = row * self.final_width + col
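                # Goal pixels are compared as indices into the flattened
                # final_height x final_width attention map.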
                loss, prob, meta = GoalPrediction.get_loss_and_prob(
                    volatile, goal, self.final_height, self.final_width)
                num_items_ += 1
                sum_loss = sum_loss + float(loss.data.cpu().numpy()[0])
                sum_prob = sum_prob + float(prob.data.cpu().numpy()[0])

                inferred_ix, row_col = self.get_inferred_value(volatile)

                if gold_ix == inferred_ix:
                    sum_acc = sum_acc + 1.0
                if row is not None and col is not None:
                    sum_dist = sum_dist + abs(row - int(inferred_ix / self.final_width)) \
                               + abs(col - int(inferred_ix % self.final_width))
                    num_visible_items += 1
                    if self.is_close_enough(inferred_ix, row, col):
                        total_epsilon_accuracy += 1
                    real_world_distance = self.compute_distance_in_real_world(
                        inferred_ix, data_point)
                    sum_real_world_distance += real_world_distance

                    # Save the map
                    instruction_string = instruction_to_string(
                        data_point.instruction, self.config)
                    # goal_x, goal_y = data_point.goal_location
                    # goal_x, goal_y = round(goal_x, 2), round(goal_y, 2)
                    # predicted_goal_x, predicted_goal_y = predicted_goal
                    # predicted_goal_x, predicted_goal_y = round(predicted_goal_x, 2), round(predicted_goal_y, 2)
                    # instruction_string = instruction_string + \
                    #                      "\n (Error: " + str(round(sum_real_world_distance, 2)) + ")" + \
                    #                      "\n %r %r %r %r \n" % (goal_x, goal_y, predicted_goal_x, predicted_goal_y)
                    # self.show_image(data_point.get_destination_list()[-1], predicted_goal, data_point.get_start_pos(),
                    #                 instruction_string)

                    # Save the generated image
                    self.global_id += 1
                    if self.global_id % 25 == 0:
                        goal_prob = GoalPrediction.generate_gold_prob(
                            goal, 32, 32)
                        predicted_goal = (int(inferred_ix / 32),
                                          inferred_ix % 32,
                                          int(inferred_ix / 32),
                                          inferred_ix % 32)
                        predicted_goal_prob = GoalPrediction.generate_gold_prob(
                            predicted_goal, 32, 32)
                        self.save_attention_prob(
                            data_point.start_image,
                            volatile["attention_probs"][:-1].view(32, 32),
                            data_point.instruction_string,
                            goal_prob[:-1].view(32, 32))
                        self.save_attention_prob(
                            data_point.start_image,
                            predicted_goal_prob[:-1].view(32, 32),
                            data_point.instruction_string,
                            goal_prob[:-1].view(32, 32))

            total_validation_loss += sum_loss
            total_validation_prob += sum_prob
            total_goal_distance += sum_dist
            total_validation_exact_accuracy += sum_acc
            total_real_world_distance += sum_real_world_distance
            num_items += num_items_

        mean_total_goal_distance = total_goal_distance / float(
            max(num_items, 1))
        mean_total_validation_loss = total_validation_loss / float(
            max(num_items, 1))
        mean_total_validation_prob = total_validation_prob / float(
            max(num_items, 1))
        mean_total_validation_accuracy = (total_validation_exact_accuracy *
                                          100.0) / float(max(num_items, 1))
        mean_total_epsilon_accuracy = (total_epsilon_accuracy * 100.0) / float(
            max(num_visible_items, 1))
        mean_real_world_distance = total_real_world_distance / float(
            max(num_visible_items, 1))

        logging.info(
            "Mean Test result: L1 Distance is %r, Loss %r, Prob %r, Acc is %r, Epsilon Accuracy is %r"
            % (mean_total_goal_distance, mean_total_validation_loss,
               mean_total_validation_prob, mean_total_validation_accuracy,
               mean_total_epsilon_accuracy))
        logging.info(
            "Num visible items %r, Num Exact Match items is %r, Num epsilon match %r, Num Items is %r "
            % (num_visible_items, total_validation_exact_accuracy,
               total_epsilon_accuracy, num_items))
        logging.info("Num visible items %r, Mean Real World Distance %r " %
                     (num_visible_items, mean_real_world_distance))

        return mean_real_world_distance

    def do_train(self,
                 train_dataset,
                 tune_dataset,
                 experiment_name,
                 save_best_model=False):
        """ Perform training """

        dataset_size = len(train_dataset)
        tensorboard = self.tensorboard

        # Test on tuning data with initialized model
        mean_real_world_distance = self.test(tune_dataset,
                                             tensorboard=tensorboard)
        best_real_world_distance = mean_real_world_distance

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)

            batch_replay_items = []
            best_real_world_distance = min(best_real_world_distance,
                                           mean_real_world_distance)

            for data_point_ix, data_point in enumerate(train_dataset):

                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix,
                                 dataset_size)

                model_state = None
                state = AgentObservedState(instruction=data_point.instruction,
                                           config=self.config,
                                           constants=self.constants,
                                           start_image=data_point.start_image,
                                           previous_action=None,
                                           pose=None,
                                           position_orientation=None,
                                           data_point=data_point)

                # Generate attention probabilities
                volatile = self.model.get_attention_prob(state, model_state)
                row, col = data_point.goal_pixel
                goal = row, col, row, col

                # Store it in the replay memory list
                if not self.ignore_none or goal[0] is not None:
                    replay_item = ReplayMemoryItem(state,
                                                   None,
                                                   0,
                                                   volatile=volatile,
                                                   goal=goal)
                    batch_replay_items.append(replay_item)

                # Perform update
                if len(batch_replay_items) > 0:
                    loss_val = self.do_update(batch_replay_items)
                    batch_replay_items = []
                    if tensorboard is not None:
                        tensorboard.log_scalar("Loss", loss_val)
                        if self.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                self.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)
                        if self.goal_prob is not None:
                            goal_prob = float(self.goal_prob.data[0])
                            tensorboard.log_scalar("goal_prob", goal_prob)
                        if self.object_detection_loss is not None:
                            object_detection_loss = float(
                                self.object_detection_loss.data[0])
                            tensorboard.log_scalar("object_detection_loss",
                                                   object_detection_loss)
                        if self.cross_entropy_loss is not None:
                            cross_entropy_loss = float(
                                self.cross_entropy_loss.data[0])
                            tensorboard.log_scalar("Cross_entropy_loss",
                                                   cross_entropy_loss)
                        if self.dist_loss is not None:
                            dist_loss = float(self.dist_loss.data[0])
                            tensorboard.log_scalar("Dist_loss", dist_loss)

            mean_real_world_distance = self.test(tune_dataset,
                                                 tensorboard=tensorboard)

            # Save the model
            if save_best_model:
                if mean_real_world_distance < best_real_world_distance:
                    self.model.save_model(
                        experiment_name +
                        "/goal_prediction_single_supervised_epoch_" +
                        str(epoch))
            else:
                self.model.save_model(
                    experiment_name +
                    "/goal_prediction_single_supervised_epoch_" + str(epoch))

class AsynchronousSupervisedLearning(AbstractLearning):
    """ Perform supervised learning """
    def __init__(self, shared_model, local_model, action_space, meta_data_util,
                 config, constants, tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.shared_model = shared_model
        self.local_model = local_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.local_model,
                num_objects=67,
                camera_angle=60,
                image_height=8,
                image_width=8,
                object_height=0)  #-2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(self.local_model)
            self.goal_prediction_loss = None

        self.optimizer = optim.Adam(shared_model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_model, self.local_model,
                                  self.calc_loss, self.optimizer, self.config,
                                  self.constants, self.tensorboard)

    def calc_loss(self, batch_replay_items):

        agent_observation_state_ls = []
        action_batch = []
        log_probabilities = []
        factor_entropy = []
        for replay_item in batch_replay_items:
            agent_observation_state_ls.append(
                replay_item.get_agent_observed_state())
            action_batch.append(replay_item.get_action())
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())

        log_probabilities = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))

        num_states = int(action_batch.size()[0])
        model_log_prob_batch = log_probabilities
        chosen_log_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))

        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(model_log_prob_batch)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        self.entropy = -torch.mean(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(chosen_log_probs) / num_states
        # Essentially we want the objective to increase and cross entropy to decrease
        loss = -objective - self.entropy_coef * self.entropy
        self.ratio = torch.abs(objective) / (self.entropy_coef * self.entropy
                                             )  # we want the ratio to be high
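        # The supervised objective is the mean log-probability of the demonstrated
        # actions; subtracting the scaled entropy term means that minimising `loss`
        # maximises both. `ratio` is not part of the loss and is only logged to
        # monitor the objective against the entropy bonus.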

        # loss = -objective + self.entropy_coef * self.cross_entropy

        # Minimize the Factor Entropy if the model is implicit factorization model
        if isinstance(self.local_model,
                      IncrementalModelRecurrentImplicitFactorizationResnet):
            self.mean_factor_entropy = torch.mean(torch.cat(factor_entropy))
            loss = loss + self.mean_factor_entropy
        else:
            self.mean_factor_entropy = None

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            self.symbolic_language_prediction_loss = self.constants["symbolic_language_prediction_coeff"] * \
                                                     self.symbolic_language_prediction_loss
            loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            self.goal_prediction_loss, self.goal_prob = self.goal_prediction_calculator.calc_loss(
                batch_replay_items)
            if self.goal_prediction_loss is not None:
                self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * \
                                            self.goal_prediction_loss
                # loss = loss + self.goal_prediction_loss  # * len(batch_replay_items)  # scale the loss
                loss = self.goal_prediction_loss
            else:
                loss = None
        else:
            self.goal_prediction_loss = None

        return loss

    @staticmethod
    def save_goal(batch_replay_items, data_point_ix, trajectory):

        assert len(batch_replay_items) == len(trajectory) + 1
        f = open(
            "../logs/oracle_images/tune_images/example_" + str(data_point_ix) +
            "/goal.txt", "w")
        for item in batch_replay_items:
            row, col, row_real, col_real = item.goal
            f.write("%r %r %r %r\n" % (row, col, row_real, col_real))
        f.flush()
        f.close()

    @staticmethod
    def do_train(shared_model,
                 config,
                 action_space,
                 meta_data_util,
                 constants,
                 train_dataset,
                 tune_dataset,
                 experiment,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 model_type,
                 use_pushover=False):
        try:
            AsynchronousSupervisedLearning.do_train_(
                shared_model, config, action_space, meta_data_util, constants,
                train_dataset, tune_dataset, experiment, experiment_name, rank,
                server, logger, model_type, use_pushover)
        except Exception:
            exc_info = sys.exc_info()
            traceback.print_exception(*exc_info)

    @staticmethod
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = AsynchronousSupervisedLearning(shared_model, local_model,
                                                 action_space, meta_data_util,
                                                 config, constants,
                                                 tensorboard)

        # Launch unity
        launch_k_unity_builds([config["port"]],
                              "./simulators/NavDroneLinuxBuild.x86_64")

        for epoch in range(1, max_epochs + 1):

            learner.epoch = epoch

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_model.load_from_state_dict(shared_model.get_state_dict())

                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                num_actions = 0
                trajectory = data_point.get_trajectory()
                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                pose = int(metadata["y_angle"] / 15.0)
                position_orientation = (metadata["x_pos"], metadata["z_pos"],
                                        metadata["y_angle"])
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    pose=pose,
                    position_orientation=position_orientation,
                    data_point=data_point)

                model_state = None
                batch_replay_items = []
                total_reward = 0

                for action in trajectory:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, volatile = \
                        local_model.get_probs(state, model_state)

                    action_counts[action] += 1

                    # Generate goal
                    if config["do_goal_prediction"]:
                        goal = learner.goal_prediction_calculator.get_goal_location(
                            metadata, data_point, 8, 8)
                        # learner.goal_prediction_calculator.save_attention_prob(image, volatile)
                        # time.sleep(5)
                    else:
                        goal = None

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   volatile=volatile,
                                                   goal=goal)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    pose = int(metadata["y_angle"] / 15.0)
                    position_orientation = (metadata["x_pos"],
                                            metadata["z_pos"],
                                            metadata["y_angle"])
                    state = state.update(
                        image,
                        action,
                        pose=pose,
                        position_orientation=position_orientation,
                        data_point=data_point)

                    num_actions += 1
                    total_reward += reward

                # Sample action using the policy
                log_probabilities, model_state, image_emb_seq, volatile = \
                    local_model.get_probs(state, model_state)

                # Generate goal
                if config["do_goal_prediction"]:
                    goal = learner.goal_prediction_calculator.get_goal_location(
                        metadata, data_point, 8, 8)
                    # learner.goal_prediction_calculator.save_attention_prob(image, volatile)
                    # time.sleep(5)
                else:
                    goal = None

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                total_reward += reward

                if tensorboard is not None:
                    tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

                # Store it in the replay memory list
                replay_item = ReplayMemoryItem(
                    state,
                    action_space.get_stop_action_index(),
                    reward,
                    log_prob=log_probabilities,
                    volatile=volatile,
                    goal=goal)
                batch_replay_items.append(replay_item)

                ###########################################
                AsynchronousSupervisedLearning.save_goal(
                    batch_replay_items, data_point_ix, trajectory)
                ###########################################

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                if len(batch_replay_items) > 0:  # 32:
                    loss_val = learner.do_update(batch_replay_items)
                    # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                    # del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        entropy = float(
                            learner.entropy.data[0]) / float(num_actions + 1)
                        tensorboard.log_scalar("entropy", entropy)
                        tensorboard.log_scalar("total_reward", total_reward)

                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)
                        if learner.goal_prob is not None:
                            goal_prob = float(learner.goal_prob.data[0])
                            tensorboard.log_scalar("goal_prob", goal_prob)
                        if learner.mean_factor_entropy is not None:
                            mean_factor_entropy = float(
                                learner.mean_factor_entropy.data[0])
                            tensorboard.log_factor_entropy_loss(
                                mean_factor_entropy)

            # Save the model
            local_model.save_model(experiment + "/supervised_learning_" +
                                   str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)

            if tune_dataset_size > 0:
                # Test on tuning data
                agent.test_goal_prediction(tune_dataset,
                                           tensorboard=tensorboard,
                                           logger=logger,
                                           pushover_logger=pushover_logger)

class AsynchronousTwoStageContextualBandit(AbstractLearning):
    """ Perform Contextual Bandit learning (Kakade and Langford (circa 2006) & Misra, Langford and Artzi EMNLP 2017)
    on the two-stage model. """
    def __init__(self, shared_navigator_model, local_navigator_model,
                 shared_predictor_model, local_predictor_model, action_space,
                 meta_data_util, config, constants, tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.shared_navigator_model = shared_navigator_model
        self.local_navigator_model = local_navigator_model
        self.shared_predictor_model = shared_predictor_model
        self.local_predictor_model = local_predictor_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.ratio = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        self.image_channels, self.image_height, self.image_width = shared_navigator_model.image_module.get_final_dimension(
        )

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_navigator_model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_navigator_model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.local_navigator_model,
                num_objects=67,
                camera_angle=60,
                image_height=self.image_height,
                image_width=self.image_width,
                object_height=0)  # -2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_navigator_model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.local_navigator_model, self.image_height,
                self.image_width)
            self.goal_prediction_loss = None

        parameters = self.shared_navigator_model.get_parameters()
        parameters.extend(self.shared_predictor_model.get_parameters())

        self.optimizer = optim.Adam(parameters, lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_navigator_model,
                                  self.local_navigator_model, self.calc_loss,
                                  self.optimizer, self.config, self.constants,
                                  self.tensorboard)

    def calc_loss(self, batch_replay_items):

        agent_observation_state_ls = []
        immediate_rewards = []
        action_batch = []
        log_probabilities = []
        factor_entropy = []
        chosen_log_goal_prob = []
        for replay_item in batch_replay_items:
            agent_observation_state_ls.append(
                replay_item.get_agent_observed_state())
            action_batch.append(replay_item.get_action())
            immediate_rewards.append(replay_item.get_reward())
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())
            chosen_log_goal_prob.append(
                replay_item.get_volatile_features()["goal_sample_prob"])

        log_probabilities = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))
        immediate_rewards = cuda_var(
            torch.from_numpy(np.array(immediate_rewards)).float())

        num_states = int(action_batch.size()[0])
        model_log_prob_batch = log_probabilities
        # model_log_prob_batch = self.model.get_probs_batch(agent_observation_state_ls)
        chosen_log_action_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))

        # Take the probability of goal generation into account
        chosen_log_goal_prob = torch.cat(chosen_log_goal_prob)

        chosen_log_probs = chosen_log_action_probs.view(
            -1) + chosen_log_goal_prob.view(-1)
        reward_log_probs = immediate_rewards * chosen_log_probs
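        # The two-stage model samples a goal and then an action at every step, so a
        # step's log-probability is the sum of the action log-prob and the sampled
        # goal's "goal_sample_prob" (presumably already a log-probability despite
        # the key name); both factors are credited with the same immediate reward.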

        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(model_log_prob_batch)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        # self.entropy = -torch.mean(torch.sum(model_log_prob_batch * model_prob_batch, 1))
        self.entropy = -torch.sum(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(reward_log_probs)  # / num_states
        # Essentially we want the objective to increase and cross entropy to decrease
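        # Linearly anneal the entropy bonus by 0.01 per epoch (floored at 0) so exploration decays.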
        entropy_coef = max(0, self.entropy_coef - self.epoch * 0.01)
        loss = -objective - entropy_coef * self.entropy
        self.ratio = torch.abs(objective) / (entropy_coef * self.entropy
                                             )  # we want the ratio to be high

        # loss = -objective + self.entropy_coef * self.cross_entropy

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            self.symbolic_language_prediction_loss = self.constants["symbolic_language_prediction_coeff"] * \
                                                     self.symbolic_language_prediction_loss
            loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            self.goal_prediction_loss, _, _ = self.goal_prediction_calculator.calc_loss(
                batch_replay_items)
            if self.goal_prediction_loss is not None:
                self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * \
                                            self.goal_prediction_loss
                loss = loss + self.goal_prediction_loss  # * len(batch_replay_items)  # scale the loss
        else:
            self.goal_prediction_loss = None

        return loss

    def _sample_goal(self, exploration_image, data_point, panaroma=True):

        state = AgentObservedState(
            instruction=data_point.instruction,
            config=self.config,
            constants=self.constants,
            start_image=exploration_image,
            previous_action=None,
            pose=None,
            position_orientation=data_point.get_start_pos(),
            data_point=data_point)

        volatile = self.local_predictor_model.get_attention_prob(
            state, model_state=None)
        attention_prob = list(
            volatile["attention_probs"].view(-1)[:-1].data.cpu().numpy())
        sampled_ix = gp.sample_action_from_prob(attention_prob)
        sampled_prob = volatile["attention_probs"][sampled_ix]
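        # sampled_ix indexes a cell of the 32 x 192 (32 x 6 images) attention grid over the panorama
        # (the final extra entry of the distribution is dropped before sampling); sampled_prob is its
        # probability and is later folded into the policy gradient as "goal_sample_prob".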
        #################################################

        # Max pointed out that when the inferred ix above is the last value, the calculations are buggy. He is right.

        predicted_row = int(sampled_ix / float(192))
        predicted_col = sampled_ix % 192
        screen_pos = (predicted_row, predicted_col)

        if panaroma:
            # Index of which of the 6 panorama images contains the goal
            region_index = int(predicted_col / 32)
            predicted_col = predicted_col % 32  # Column within that image where the goal is
            pos = data_point.get_start_pos()
            new_pos_angle = GoalPredictionSingle360ImageSupervisedLearningFromDisk.\
                get_new_pos_angle_from_region_index(region_index, pos)
            metadata = {
                "x_pos": pos[0],
                "z_pos": pos[1],
                "y_angle": new_pos_angle
            }
        else:
            pos = data_point.get_start_pos()
            metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": pos[2]}

        row, col = predicted_row + 0.5, predicted_col + 0.5

        start_pos = current_pos_from_metadata(metadata)
        start_pose = current_pose_from_metadata(metadata)

        goal_pos = data_point.get_destination_list()[-1]
        height_drone = 2.5
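        # Back-project the selected pixel (row, col) to an (x, z) position on the ground plane using
        # the camera geometry, then measure its distance to the true goal for logging.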
        x_gen, z_gen = get_inverse_object_position(
            row, col, height_drone, 30, 32, 32,
            (start_pos[0], start_pos[1], start_pose))
        predicted_goal_pos = (x_gen, z_gen)
        x_goal, z_goal = goal_pos

        x_diff = x_gen - x_goal
        z_diff = z_gen - z_goal

        dist = math.sqrt(x_diff * x_diff + z_diff * z_diff)

        return predicted_goal_pos, dist, screen_pos, sampled_prob

    @staticmethod
    def do_train(shared_navigator_model,
                 shared_predictor_model,
                 config,
                 action_space,
                 meta_data_util,
                 constants,
                 train_dataset,
                 tune_dataset,
                 experiment,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 navigator_model_type,
                 predictor_model_type,
                 use_pushover=False):
        try:
            AsynchronousTwoStageContextualBandit.do_train_(
                shared_navigator_model, shared_predictor_model, config,
                action_space, meta_data_util, constants, train_dataset,
                tune_dataset, experiment, experiment_name, rank, server,
                logger, navigator_model_type, predictor_model_type,
                use_pushover)
        except Exception:
            exc_info = sys.exc_info()
            traceback.print_exception(*exc_info)

    @staticmethod
    def do_train_(shared_navigator_model,
                  shared_predictor_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  navigator_model_type,
                  predictor_model_type,
                  use_pushover=False):

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_predictor_model = predictor_model_type(
            config,
            constants,
            final_model_type="unet-positional-encoding",
            final_dimension=(64, 32, 32 * 6))
        local_navigator_model = navigator_model_type(config, constants)
        # local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = PredictorPlannerAgent(server=server,
                                      predictor_model=local_predictor_model,
                                      model=local_navigator_model,
                                      test_policy=test_policy,
                                      action_space=action_space,
                                      meta_data_util=meta_data_util,
                                      config=config,
                                      constants=constants)
        logger.log("Created Agent...")

        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = AsynchronousTwoStageContextualBandit(
            shared_navigator_model, local_navigator_model,
            shared_predictor_model, local_predictor_model, action_space,
            meta_data_util, config, constants, tensorboard)

        # Launch unity
        launch_k_unity_builds([config["port"]],
                              "./simulators/NavDroneLinuxBuild.x86_64")

        for epoch in range(1, max_epochs + 1):

            learner.epoch = epoch
            task_completion_accuracy = 0
            mean_stop_dist_error = 0
            stop_dist_errors = []
            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_navigator_model.load_from_state_dict(
                    shared_navigator_model.get_state_dict())
                local_predictor_model.load_from_state_dict(
                    shared_predictor_model.get_state_dict())

                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                num_actions = 0
                max_num_actions = constants["horizon"] + constants[
                    "max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                # Capture a 360-degree exploration panorama from the agent's start position
                panorama = agent.get_exploration_image()

                # Sample a goal location and compute 3D mapping
                predicted_goal, predictor_error, predicted_pixel, sample_prob = learner._sample_goal(
                    panorama, data_point, panaroma=True)

                pose = int(metadata["y_angle"] / 15.0)
                position_orientation = (metadata["x_pos"], metadata["z_pos"],
                                        metadata["y_angle"])
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    pose=pose,
                    position_orientation=position_orientation,
                    data_point=data_point)
                current_bot_location = metadata["x_pos"], metadata["z_pos"]
                current_bot_pose = metadata["y_angle"]
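                # Project the predicted goal into the agent's current egocentric frame (a 32 x 32 grid)
                # so the navigator can condition on it via state.goal.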
                state.goal = PredictorPlannerAgent.get_goal_location(
                    current_bot_location, current_bot_pose, predicted_goal, 32,
                    32)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                forced_stop = True

                while num_actions < max_num_actions:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, volatile = \
                        local_navigator_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
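                    # Attach the goal-sampling probability so calc_loss can include it in the
                    # policy-gradient term for this step.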
                    volatile["goal_sample_prob"] = sample_prob
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   volatile=volatile)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    pose = int(metadata["y_angle"] / 15.0)
                    position_orientation = (metadata["x_pos"],
                                            metadata["z_pos"],
                                            metadata["y_angle"])
                    state = state.update(
                        image,
                        action,
                        pose=pose,
                        position_orientation=position_orientation,
                        data_point=data_point)

                    current_bot_location = metadata["x_pos"], metadata["z_pos"]
                    current_bot_pose = metadata["y_angle"]
                    state.goal = PredictorPlannerAgent.get_goal_location(
                        current_bot_location, current_bot_pose, predicted_goal,
                        32, 32)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                total_reward += reward

                if metadata["stop_dist_error"] < 5.0:
                    task_completion_accuracy += 1
                mean_stop_dist_error += metadata["stop_dist_error"]
                stop_dist_errors.append(metadata["stop_dist_error"])

                if tensorboard is not None:
                    tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

                # Store it in the replay memory list
                if not forced_stop:
                    volatile["goal_sample_prob"] = sample_prob
                    replay_item = ReplayMemoryItem(
                        state,
                        action_space.get_stop_action_index(),
                        reward,
                        log_prob=log_probabilities,
                        volatile=volatile)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                if len(batch_replay_items) > 0:  # 32:
                    loss_val = learner.do_update(batch_replay_items)
                    # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                    # del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        tensorboard.log_scalar("gold_sample_prob",
                                               float(sample_prob.data[0]))
                        tensorboard.log_scalar("predicted_error",
                                               predictor_error)
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        entropy = float(
                            learner.entropy.data[0]) / float(num_actions + 1)
                        tensorboard.log_scalar("entropy", entropy)
                        tensorboard.log_scalar("total_reward", total_reward)

                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)

            # Save the model
            local_navigator_model.save_model(experiment +
                                             "/navigator_contextual_bandit_" +
                                             str(rank) + "_epoch_" +
                                             str(epoch))
            local_predictor_model.save_model(experiment +
                                             "/predictor_contextual_bandit_" +
                                             str(rank) + "_epoch_" +
                                             str(epoch))
            logger.log("Training data action counts %r" % action_counts)
            mean_stop_dist_error = mean_stop_dist_error / float(
                len(train_dataset))
            task_completion_accuracy = (task_completion_accuracy *
                                        100.0) / float(len(train_dataset))
            logger.log("Training: Mean stop distance error %r" %
                       mean_stop_dist_error)
            logger.log("Training: Task completion accuracy %r " %
                       task_completion_accuracy)
            bins = range(0, 80, 3)  # range of distance
            histogram, _ = np.histogram(stop_dist_errors, bins)
            logger.log("Histogram of train errors %r " % histogram)

            if tune_dataset_size > 0:
                # Test on tuning data
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)


class TmpStreetViewAsynchronousContextualBandit(AbstractLearning):
    """ Temp file with modification for streetview corpus.
    Perform Contextual Bandit learning (Kakade and Langford (circa 2006) & Misra, Langford and Artzi EMNLP 2017) """
    def __init__(self, shared_model, local_model, action_space, meta_data_util,
                 config, constants, tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.shared_model = shared_model
        self.local_model = local_model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.ratio = None
        self.epoch = 0
        self.entropy_coef = constants["entropy_coefficient"]

        # self.image_channels, self.image_height, self.image_width = shared_model.image_module.get_final_dimension()

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(
                self.local_model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(
                self.local_model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.local_model,
                num_objects=67,
                camera_angle=60,
                image_height=self.image_height,
                image_width=self.image_width,
                object_height=0)  # -2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(
                self.local_model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(
                self.local_model, self.image_height, self.image_width)
            self.goal_prediction_loss = None

        self.optimizer = optim.Adam(shared_model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.shared_model, self.local_model,
                                  self.calc_loss, self.optimizer, self.config,
                                  self.constants, self.tensorboard)

    def calc_loss(self, batch_replay_items):

        agent_observation_state_ls = []
        immediate_rewards = []
        action_batch = []
        log_probabilities = []
        factor_entropy = []
        for replay_item in batch_replay_items:
            agent_observation_state_ls.append(
                replay_item.get_agent_observed_state())
            action_batch.append(replay_item.get_action())
            immediate_rewards.append(replay_item.get_reward())
            log_probabilities.append(replay_item.get_log_prob())
            factor_entropy.append(replay_item.get_factor_entropy())

        log_probabilities = torch.cat(log_probabilities)
        action_batch = cuda_var(torch.from_numpy(np.array(action_batch)))
        immediate_rewards = cuda_var(
            torch.from_numpy(np.array(immediate_rewards)).float())

        num_states = int(action_batch.size()[0])
        model_log_prob_batch = log_probabilities
        # model_log_prob_batch = self.model.get_probs_batch(agent_observation_state_ls)
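        # Contextual-bandit objective for the StreetView variant: immediate reward times the chosen
        # action's log-probability (there is no separate goal-sampling term here).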
        chosen_log_probs = model_log_prob_batch.gather(
            1, action_batch.view(-1, 1))
        reward_log_probs = immediate_rewards * chosen_log_probs.view(-1)

        gold_distribution = cuda_var(
            torch.FloatTensor([0.6719, 0.1457, 0.1435, 0.0387]))
        model_prob_batch = torch.exp(model_log_prob_batch)
        mini_batch_action_distribution = torch.mean(model_prob_batch, 0)

        self.cross_entropy = -torch.sum(
            gold_distribution * torch.log(mini_batch_action_distribution))
        # self.entropy = -torch.mean(torch.sum(model_log_prob_batch * model_prob_batch, 1))
        self.entropy = -torch.sum(
            torch.sum(model_log_prob_batch * model_prob_batch, 1))
        objective = torch.sum(reward_log_probs)  # / num_states
        # Essentially we want the objective to increase and cross entropy to decrease
        entropy_coef = max(0, self.entropy_coef - self.epoch * 0.01)
        loss = -objective - entropy_coef * self.entropy
        self.ratio = torch.abs(objective) / (entropy_coef * self.entropy
                                             )  # we want the ratio to be high

        # loss = -objective + self.entropy_coef * self.cross_entropy

        if self.config["do_action_prediction"]:
            self.action_prediction_loss = self.action_prediction_loss_calculator.calc_loss(
                batch_replay_items)
            if self.action_prediction_loss is not None:
                self.action_prediction_loss = self.constants[
                    "action_prediction_coeff"] * self.action_prediction_loss
                loss = loss + self.action_prediction_loss
        else:
            self.action_prediction_loss = None

        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss = self.temporal_autoencoder_loss_calculator.calc_loss(
                batch_replay_items)
            if self.temporal_autoencoder_loss is not None:
                self.temporal_autoencoder_loss = \
                    self.constants["temporal_autoencoder_coeff"] * self.temporal_autoencoder_loss
                loss = loss + self.temporal_autoencoder_loss
        else:
            self.temporal_autoencoder_loss = None

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(
                batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants[
                    "object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss = \
                self.symbolic_language_prediction_loss_calculator.calc_loss(batch_replay_items)
            self.symbolic_language_prediction_loss = self.constants["symbolic_language_prediction_coeff"] * \
                                                     self.symbolic_language_prediction_loss
            loss = loss + self.symbolic_language_prediction_loss
        else:
            self.symbolic_language_prediction_loss = None

        if self.config["do_goal_prediction"]:
            self.goal_prediction_loss, _, _ = self.goal_prediction_calculator.calc_loss(
                batch_replay_items)
            if self.goal_prediction_loss is not None:
                self.goal_prediction_loss = self.constants["goal_prediction_coeff"] * \
                                            self.goal_prediction_loss
                loss = loss + self.goal_prediction_loss  # * len(batch_replay_items)  # scale the loss
        else:
            self.goal_prediction_loss = None

        return loss

    @staticmethod
    def do_train(shared_model,
                 config,
                 action_space,
                 meta_data_util,
                 constants,
                 train_dataset,
                 tune_dataset,
                 experiment,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 model_type,
                 use_pushover=False):
        try:
            TmpStreetViewAsynchronousContextualBandit.do_train_(
                shared_model, config, action_space, meta_data_util, constants,
                train_dataset, tune_dataset, experiment, experiment_name, rank,
                server, logger, model_type, use_pushover)
        except Exception:
            exc_info = sys.exc_info()
            traceback.print_exception(*exc_info)

    @staticmethod
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)
        # local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        action_counts = [0] * action_space.num_actions()
        action_rewards = [0.0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = TmpStreetViewAsynchronousContextualBandit(
            shared_model, local_model, action_space, meta_data_util, config,
            constants, tensorboard)

        for epoch in range(1, max_epochs + 1):

            learner.epoch = epoch
            task_completion_accuracy = 0
            mean_stop_dist_error = 0

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                # local_model.load_state_dict(shared_model.state_dict())
                local_model.load_from_state_dict(shared_model.get_state_dict())

                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)
                    mean_action_reward = [
                        action_sum / max(1.0, action_count) for (
                            action_sum,
                            action_count) in zip(action_rewards, action_counts)
                    ]
                    logger.log("Training data action rewards %r" %
                               mean_action_reward)

                num_actions = 0
                max_num_actions = constants["horizon"] + constants[
                    "max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                state = AgentObservedState(instruction=data_point.instruction,
                                           config=config,
                                           constants=constants,
                                           start_image=image,
                                           previous_action=None,
                                           data_point=data_point)
                # state.goal = GoalPrediction.get_goal_location(metadata, data_point,
                #                                               learner.image_height, learner.image_width)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                forced_stop = True

                while num_actions < max_num_actions:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, volatile = \
                        local_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    # Generate goal
                    if config["do_goal_prediction"]:
                        goal = learner.goal_prediction_calculator.get_goal_location(
                            metadata, data_point, learner.image_height,
                            learner.image_width)
                    else:
                        goal = None

                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)
                    action_rewards[action] += reward

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   volatile=volatile,
                                                   goal=goal)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    state = state.update(image, action, data_point=data_point)
                    # state.goal = GoalPrediction.get_goal_location(metadata, data_point,
                    #                                               learner.image_height, learner.image_width)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                total_reward += reward

                if metadata["navigation_error"] <= 5.0:
                    task_completion_accuracy += 1
                mean_stop_dist_error += metadata["navigation_error"]

                if tensorboard is not None:
                    tensorboard.log_scalar("navigation_error",
                                           metadata["navigation_error"])

                # Store it in the replay memory list
                if not forced_stop:
                    replay_item = ReplayMemoryItem(
                        state,
                        action_space.get_stop_action_index(),
                        reward,
                        log_prob=log_probabilities,
                        volatile=volatile,
                        goal=goal)
                    batch_replay_items.append(replay_item)

                # Update the scores based on meta_data
                # self.meta_data_util.log_results(metadata)

                # Perform update
                if len(batch_replay_items) > 0:  # 32:
                    loss_val = learner.do_update(batch_replay_items)
                    # self.action_prediction_loss_calculator.predict_action(batch_replay_items)
                    # del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        entropy = float(
                            learner.entropy.data[0]) / float(num_actions + 1)
                        tensorboard.log_scalar("entropy", entropy)
                        tensorboard.log_scalar("total_reward", total_reward)

                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        logger.log(
                            "Avg. Entropy %r, Total Reward %r, Rollout Length %r, stop-error %r, ratio %r "
                            % (entropy, total_reward, num_actions + 1,
                               metadata["navigation_error"], ratio))

                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)

            # Save the model
            local_model.save_model(experiment + "/contextual_bandit_" +
                                   str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)
            mean_action_reward = [
                action_sum / max(1.0, action_count)
                for (action_sum,
                     action_count) in zip(action_rewards, action_counts)
            ]
            logger.log("Training data action rewards %r" % mean_action_reward)
            mean_stop_dist_error = mean_stop_dist_error / float(
                len(train_dataset))
            task_completion_accuracy = (task_completion_accuracy *
                                        100.0) / float(len(train_dataset))
            logger.log("Training: Mean stop distance error %r" %
                       mean_stop_dist_error)
            logger.log("Training: Task completion accuracy %r " %
                       task_completion_accuracy)

            if tune_dataset_size > 0:
                logger.log("Evaluating on the tune split")
                # Test on tuning data
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)

                # Test on a selected subset of the train set
                # logger.log("Evaluating on first 50 examples in the train split.")
                # agent.test(train_dataset[0:50], tensorboard=tensorboard,
                #            logger=logger, pushover_logger=pushover_logger)

    @staticmethod
    def do_test(shared_model,
                config,
                action_space,
                meta_data_util,
                constants,
                test_dataset,
                experiment_name,
                rank,
                server,
                logger,
                model_type,
                use_pushover=False):
        try:
            TmpStreetViewAsynchronousContextualBandit.do_test_(
                shared_model, config, action_space, meta_data_util, constants,
                test_dataset, experiment_name, rank, server, logger,
                model_type, use_pushover)
        except Exception:
            exc_info = sys.exc_info()
            traceback.print_exception(*exc_info)

    @staticmethod
    def do_test_(shared_model,
                 config,
                 action_space,
                 meta_data_util,
                 constants,
                 test_dataset,
                 experiment_name,
                 rank,
                 server,
                 logger,
                 model_type,
                 use_pushover=False):

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)
        # local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        tune_dataset_size = len(test_dataset)

        local_model.load_from_state_dict(shared_model.get_state_dict())

        if tune_dataset_size > 0:
            # Test on tuning data
            agent.test(test_dataset,
                       tensorboard=tensorboard,
                       logger=logger,
                       pushover_logger=pushover_logger)


class GoalPredictionSingle360ImageSupervisedLearningFromDisk(AbstractLearning):
    """ Perform goal prediction on single 360-degree images (as opposed to image sequences)
    stored on disk, and hence does not need a client or server. """

    CLOCKWISE, BACKSTICH = range(2)
    image_stich = CLOCKWISE

    MODE, MEAN, REALMEAN, MEANAROUNDMODE = range(4)
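    # Candidate procedures for reducing the predicted pixel distribution to a single goal location;
    # the default (set in __init__ below) is MODE.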

    def __init__(self, model, action_space, meta_data_util, config, constants, tensorboard):
        self.max_epoch = constants["max_epochs"]
        self.model = model
        self.action_space = action_space
        self.meta_data_util = meta_data_util
        self.config = config
        self.constants = constants
        self.tensorboard = tensorboard
        self.entropy = None
        self.cross_entropy = None
        self.epoch = 0
        self.global_id = 0
        self.entropy_coef = constants["entropy_coefficient"]
        self.final_num_channels, self.final_height, self.final_width = model.image_module.get_final_dimension()

        self.ignore_none = True
        self.inference_procedure = GoalPredictionSingle360ImageSupervisedLearningFromDisk.MODE

        self.vocab = {}
        vocab_path = config["vocab_file"]
        word_index = 0
        with open(vocab_path) as f:
            for line in f.readlines():
                token = line.strip()
                self.vocab[token] = word_index
                word_index += 1

        # Auxiliary Objectives
        if self.config["do_action_prediction"]:
            self.action_prediction_loss_calculator = ActionPrediction(self.model)
            self.action_prediction_loss = None
        if self.config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_loss_calculator = TemporalAutoEncoder(self.model)
            self.temporal_autoencoder_loss = None
        if self.config["do_object_detection"]:
            self.object_detection_loss_calculator = ObjectPixelIdentification(
                self.model, num_objects=67, camera_angle=60, image_height=self.final_height,
                image_width=self.final_width, object_height=0)  # -2.5)
            self.object_detection_loss = None
        if self.config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_loss_calculator = SymbolicLanguagePrediction(self.model)
            self.symbolic_language_prediction_loss = None
        if self.config["do_goal_prediction"]:
            self.goal_prediction_calculator = GoalPrediction(self.model, self.final_height, self.final_width)
            self.goal_prediction_loss = None

        self.cross_entropy_loss = None
        self.dist_loss = None

        self.optimizer = optim.Adam(model.get_parameters(),
                                    lr=constants["learning_rate"])
        AbstractLearning.__init__(self, self.model, self.calc_loss, self.optimizer,
                                  self.config, self.constants, self.tensorboard)

        logging.info("Created Single Image goal predictor with ignore_none %r", self.ignore_none)

    def calc_loss(self, batch_replay_items):

        # Only compute the goal prediction loss
        # loss = None
        # for replay_item in batch_replay_items:
        #     self.goal_prediction_loss, self.goal_prob, meta = self.goal_prediction_calculator.calc_loss(
        #         [replay_item])
        #     if loss is None:
        #         loss = self.goal_prediction_loss
        #     else:
        #         loss += self.goal_prediction_loss
        #
        # loss = loss / float(len(batch_replay_items))
        self.goal_prediction_loss, self.goal_prob, meta = self.goal_prediction_calculator.calc_loss(batch_replay_items)
        loss = self.goal_prediction_loss

        if self.config["do_object_detection"]:
            self.object_detection_loss = self.object_detection_loss_calculator.calc_loss(batch_replay_items)
            if self.object_detection_loss is not None:
                self.object_detection_loss = self.constants["object_detection_coeff"] * self.object_detection_loss
                loss = loss + self.object_detection_loss
        else:
            self.object_detection_loss = None

        self.cross_entropy_loss = meta["cross_entropy"]
        self.dist_loss = meta["dist_loss"]

        return loss

    @staticmethod
    def get_new_pos_angle(data_point):

        pos = data_point.get_start_pos()
        metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": pos[2]}
        turn_angle = get_turn_angle_from_metadata_datapoint(metadata, data_point)

        assert 180.0 >= turn_angle >= -180.0
        if 30.0 >= turn_angle > -30.0:
            ix = 3  # ix = 0
            mean_turn_angle = 0
        elif 90.0 >= turn_angle > 30.0:
            ix = 4  # ix = 1
            mean_turn_angle = 60
        elif 150.0 >= turn_angle > 90.0:
            ix = 5  # ix = 2
            mean_turn_angle = 120
        elif -30 >= turn_angle > -90.0:
            ix = 2  # ix = 5
            mean_turn_angle = -60
        elif -90.0 >= turn_angle > -150.0:
            ix = 1  # ix = 4
            mean_turn_angle = -120
        else:
            ix = 0  # ix = 3
            mean_turn_angle = 180

        new_pos_angle = pos[2] + mean_turn_angle
        while new_pos_angle < -180:
            new_pos_angle += 360.0
        while new_pos_angle > 180:
            new_pos_angle -= 360.0

        return new_pos_angle, ix

    @staticmethod
    def get_new_pos_angle_from_region_index(region_index, start_pos):
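        # Map the region index (which of the 6 panorama images contains the goal) to that image's
        # mean turn angle, rotate the start pose by it, and wrap the result to [-180, 180].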

        if region_index == 3:
            mean_turn_angle = 0
        elif region_index == 4:
            mean_turn_angle = 60
        elif region_index == 5:
            mean_turn_angle = 120
        elif region_index == 0:
            mean_turn_angle = 180
        elif region_index == 1:
            mean_turn_angle = -120
        elif region_index == 2:
            mean_turn_angle = -60
        else:
            raise AssertionError("Region index should be in 0 to 6.")

        new_pos_angle = start_pos[2] + mean_turn_angle
        while new_pos_angle < -180:
            new_pos_angle += 360.0
        while new_pos_angle > 180:
            new_pos_angle -= 360.0

        return new_pos_angle

    @staticmethod
    def parse(folder_name, dataset, model, config, format_type="numpy"):

        start = time.time()
        num_channel, height, width = model.image_module.get_final_dimension()

        # Read images
        num_examples = len(os.listdir(folder_name))

        if format_type == "numpy":

            image_dataset = []

            # Read panorama images
            for i in range(0, num_examples):
                example_folder_name = folder_name + "/example_" + str(i)
                image_np = np.load(example_folder_name + "/image_numpy.npy")
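                # image_numpy.npy holds the 6 camera views stacked along the channel axis (6 x 3
                # channels); split them into height x width x 3 slices, reorder to the panorama
                # stitching order and hstack them, then convert back to channels-first.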

                slices = []
                for j in range(0, 6):
                    slices.append(image_np[j * 3:(j + 1) * 3, :, :].swapaxes(0, 1).swapaxes(1, 2))  # height x width x 3

                images = [slices[3], slices[4], slices[5], slices[0], slices[1], slices[2]]
                images = np.hstack(images)
                images = images.swapaxes(1, 2).swapaxes(0, 1)
                image_dataset.append([images])

        elif format_type == "png":
            image_dataset = []

            # Read panorama images
            for i in range(0, num_examples):
                example_folder_name = folder_name + "/example_" + str(i)
                images = []

                # image_order = range(0, 6)  # clockwise
                image_order = [3, 4, 5, 0, 1, 2]

                for ix in image_order:  # panorama consists of 6 images stitched together
                    img = scipy.misc.imread(example_folder_name + "/image_" + str(ix) + ".png")
                    images.append(img)
                images = np.hstack(images)
                images = images.swapaxes(1, 2).swapaxes(0, 1)
                image_dataset.append([images])
        else:
            raise AssertionError("")

        # Read the goal state. The data for the single image can be
        # directly computed and does not need to be saved.
        goal_dataset = []
        for i in range(0, num_examples):

            data_point = dataset[i]
            new_pos_angle, ix = GoalPredictionSingle360ImageSupervisedLearningFromDisk.get_new_pos_angle(data_point)

            # Modify the pos to turn towards the image so we can compute the goal location relative to a single image.
            pos = data_point.get_start_pos()
            new_pos = (pos[0], pos[1], new_pos_angle)
            original_start_pos = pos
            data_point.start_pos = new_pos
            metadata = {"x_pos": new_pos[0], "z_pos": new_pos[1], "y_angle": new_pos[2]}
            new_turn_angle = get_turn_angle_from_metadata_datapoint(metadata, data_point)
            assert 30.0 >= new_turn_angle >= -30.0, "Found turn angle of " + str(new_turn_angle)
            goal_location = [GoalPrediction.get_goal_location(metadata, data_point, height=32, width=32)]
            _, _, row, col = goal_location[0]
            data_point.start_pos = original_start_pos

            # print("Drone's original angle is %r, New Pos Angle is %r ", original_start_pos[2], new_pos_angle)
            # print("Index is %r, Goal is %r " % (ix, goal_location[0]))
            if row is not None and col is not None:
                col = col + ix * 32.0
                row_new, col_new, row1, col1 = [int(round(row)), int(round(col)), row, col]
                if row_new >= 32:
                    row_new = 31
                elif row_new < 0:
                    row_new = 0
                if col_new >= 192:
                    col_new = 191
                elif col_new < 0:
                    col_new = 0

                goal = [row_new, col_new, row1, col1]
                goal_location = [goal]

                # image = image_dataset[i][0]
                # goal_prob = GoalPrediction.generate_gold_prob(goal, 32, 32 * 6)
                # goal_prob = goal_prob[:-1].view(32, 32 * 6)
                # image_flipped = image[:, :, :].swapaxes(0, 1).swapaxes(1, 2)
                # image_flipped = scipy.misc.imresize(image_flipped, (128 * 5, 128 * 6 * 5))
                # goal_map = goal_prob.cpu().data.numpy()
                # if np.sum(goal_map) > 0.01:
                #     for k in range(0, 32):
                #         for j in range(0, 32 * 6):
                #             if goal_map[k][j] < 0.01:
                #                 goal_map[k][j] = 0.0
                #         goal_map = scipy.misc.imresize(goal_map, (128 * 5, 128 * 6 * 5))
                # else:
                #     goal_map = None
                #
                # plt.imshow(image_flipped)
                # # if goal_map is not None:
                # #     plt.imshow(goal_map, cmap='jet', alpha=0.5)
                #
                # plt.title(instruction_to_string(data_point.instruction, config))
                # plt.savefig("./paper_figures/goal_" + str(i) + "_1.png")
                # plt.clf()

            # start_pos = current_pos_from_metadata(metadata)
            # start_pose = current_pose_from_metadata(metadata)
            # if row is not None and col is not None:
            #     goal_pos = data_point.get_destination_list()[-1]
            #     x_goal, z_goal = goal_pos
            #     height_drone = 2.5
            #     x_gen, z_gen = get_inverse_object_position(row, col, height_drone, 30, height, width,
            #                                                (start_pos[0], start_pos[1], start_pose))
            #     x_diff = x_gen - x_goal
            #     z_diff = z_gen - z_goal
            #     dist = math.sqrt(x_diff * x_diff + z_diff * z_diff)
            #     assert dist < 0.5, "forward computation of goal should match inverse computation"
            # else:
            #     print("Warning: None found! ")

            goal_dataset.append(goal_location)

        assert len(image_dataset) == len(dataset) and len(goal_dataset) == len(dataset)

        end = time.time()
        logging.info("Parsed dataset of size %r in time % seconds", len(image_dataset), (end - start))

        return image_dataset, goal_dataset

    @staticmethod
    def parse_oracle_turn(folder_name, dataset, model):

        start = time.time()

        num_channel, height, width = model.image_module.get_final_dimension()

        # Read images
        image_dataset = []
        num_examples = len(os.listdir(folder_name))
        for i in range(0, num_examples):
            data_point = dataset[i]

            ################################################
            pos = data_point.get_start_pos()
            metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": pos[2]}

            turn_angle = get_turn_angle_from_metadata_datapoint(metadata, data_point)

            assert 180.0 >= turn_angle >= -180.0

            if 30.0 >= turn_angle > -30.0:
                ix = 0
                mean_turn_angle = 0
            elif 90.0 >= turn_angle > 30.0:
                ix = 1
                mean_turn_angle = 60
            elif 150.0 >= turn_angle > 90.0:
                ix = 2
                mean_turn_angle = 120
            elif -30 >= turn_angle > -90.0:
                ix = 5
                mean_turn_angle = -60
            elif -90.0 >= turn_angle > -150.0:
                ix = 4
                mean_turn_angle = -120
            else:
                ix = 3
                mean_turn_angle = 180

            print("Pose is %r, Turn Angle is %r and Mean Turn Angle is %r " % (pos[2], turn_angle, mean_turn_angle))
            new_pos_angle = pos[2] + mean_turn_angle

            while new_pos_angle < -180:
                new_pos_angle += 360.0
            while new_pos_angle > 180:
                new_pos_angle -= 360.0

            # Modify the pos to turn towards the image
            new_pos = (pos[0], pos[1], new_pos_angle)
            data_point.start_pos = new_pos

            pos = data_point.get_start_pos()
            metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": pos[2]}
            new_turn_angle = get_turn_angle_from_metadata_datapoint(metadata, data_point)
            assert 30.0 >= new_turn_angle >= -30.0, "Found turn angle of " + str(new_turn_angle)
            ################################################

            example_folder_name = folder_name + "/example_" + str(i)
            img = scipy.misc.imread(example_folder_name + "/image_" + str(ix) + ".png").swapaxes(1, 2).swapaxes(0, 1)
            images = [img]
            image_dataset.append(images)

        assert len(image_dataset) == len(dataset)

        # Read the goal state. The data for the single image can be
        # directly computed and does not need to be saved.
        goal_dataset = []
        for i in range(0, num_examples):
            data_point = dataset[i]

            pos = data_point.get_start_pos()
            metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": pos[2]}

            goal_location = [GoalPrediction.get_goal_location(metadata, data_point, height, width)]
            _, _, row, col = goal_location[0]

            start_pos = current_pos_from_metadata(metadata)
            start_pose = current_pose_from_metadata(metadata)

            if row is not None and col is not None:
                goal_pos = data_point.get_destination_list()[-1]
                x_goal, z_goal = goal_pos
                height_drone = 2.5
                x_gen, z_gen = get_inverse_object_position(row, col, height_drone, 30, height, width,
                                                           (start_pos[0], start_pos[1], start_pose))
                x_diff = x_gen - x_goal
                z_diff = z_gen - z_goal
                dist = math.sqrt(x_diff * x_diff + z_diff * z_diff)
                assert dist < 0.5, "forward computation of goal should match inverse computation"
            else:
                print("Warning: None found! ")

            goal_dataset.append(goal_location)

        end = time.time()
        logging.info("Parsed dataset of size %r in time % seconds", len(image_dataset), (end - start))
        return image_dataset, goal_dataset

    def convert_to_id(self, instruction):
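        """ Convert a whitespace-tokenized instruction string into a list of vocabulary ids,
        skipping out-of-vocabulary tokens. """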
        tk_seq = instruction.split()
        token_ids = []
        for tk in tk_seq:
            if tk in self.vocab:
                token_ids.append(self.vocab[tk])
            else:
                print("Out of vocabulary word. Ignoring ", tk)
        return token_ids

    def is_close_enough(self, inferred_ix, row, col):
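        """ Return True if the pixel decoded from inferred_ix is within 10% of the larger
        attention-map dimension (in Euclidean pixel distance) of the gold (row, col). """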
        predicted_row = int(inferred_ix / float(self.final_width))
        predicted_col = inferred_ix % self.final_width

        row_diff = row - predicted_row
        col_diff = col - predicted_col

        dist = math.sqrt(row_diff * row_diff + col_diff * col_diff)

        max_dim = float(max(self.final_height, self.final_width))
        if dist < 0.1 * max_dim:
            return True
        else:
            return False

    def compute_distance_in_real_world(self, inferred_ix, row_col, data_point, panaroma=True):
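        """ Back-project a predicted attention cell to world coordinates and return the
        predicted (x, z) position together with its Euclidean distance to the true goal.
        If row_col is None, the cell is decoded from inferred_ix and offset by half a cell;
        with panaroma=True the predicted column also selects one of the six 32-pixel-wide
        camera regions, which determines the orientation used for the inverse projection. """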

        if row_col is None:
            predicted_row = int(inferred_ix / float(self.final_width))
            predicted_col = inferred_ix % self.final_width
        else:
            predicted_row, predicted_col = row_col

        if panaroma:
            region_index = int(predicted_col / 32)
            predicted_col = predicted_col % 32
            pos = data_point.get_start_pos()
            new_pos_angle = GoalPredictionSingle360ImageSupervisedLearningFromDisk.\
                get_new_pos_angle_from_region_index(region_index, pos)
            metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": new_pos_angle}
        else:
            pos = data_point.get_start_pos()
            metadata = {"x_pos": pos[0], "z_pos": pos[1], "y_angle": pos[2]}

        if row_col is None:
            row, col = predicted_row + 0.5, predicted_col + 0.5
        else:
            row, col = predicted_row, predicted_col

        start_pos = current_pos_from_metadata(metadata)
        start_pose = current_pose_from_metadata(metadata)

        goal_pos = data_point.get_destination_list()[-1]
        height_drone = 2.5
        x_gen, z_gen = get_inverse_object_position(row, col, height_drone, 30, 32, 32,
                                                   (start_pos[0], start_pos[1], start_pose))
        x_goal, z_goal = goal_pos

        x_diff = x_gen - x_goal
        z_diff = z_gen - z_goal

        dist = math.sqrt(x_diff * x_diff + z_diff * z_diff)
        return (x_gen, z_gen), dist

    def get_inferred_value(self, volatile):
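        """ Decode the predicted goal from the attention distribution according to
        self.inference_procedure: MODE returns the argmax index, MEAN the probability-weighted
        mean position over the full 32x192 map, and MEANAROUNDMODE the renormalized mean over
        a 1x3 window around the argmax. Returns (inferred_ix, row_col); row_col may be None. """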

        if self.inference_procedure == GoalPredictionSingle360ImageSupervisedLearningFromDisk.MODE:
            # Mode setting
            inferred_ix = int(torch.max(volatile["attention_logits"], 0)[1].data.cpu().numpy()[0])

            return inferred_ix, None

        elif self.inference_procedure == GoalPredictionSingle360ImageSupervisedLearningFromDisk.MEAN:
            prob_values = volatile["attention_probs"][:-1].view(32, 192).data.cpu().numpy()
            expected_row = 0
            expected_col = 0
            for row in range(0, 32):
                for col in range(0, 192):
                    expected_row = expected_row + row * prob_values[row, col]
                    expected_col = expected_col + col * prob_values[row, col]

            mode_ix = int(torch.max(volatile["attention_logits"], 0)[1].data.cpu().numpy()[0])
            row_ = int(mode_ix/192)
            col_ = mode_ix % 192
            inferred_ix = expected_row * 192.0 + expected_col
            print("Expected Row is %r Mode Row is %r and and Expected Col is %r, Mode Col is %r "
                  % (expected_row, row_, expected_col, col_))
            if inferred_ix > 32 * 192:
                inferred_ix = 32 * 192

            return inferred_ix, None

        elif self.inference_procedure == GoalPredictionSingle360ImageSupervisedLearningFromDisk.MEANAROUNDMODE:

            mode_ix = int(torch.max(volatile["attention_logits"], 0)[1].data.cpu().numpy()[0])
            mode_row = int(mode_ix / 192)
            mode_col = mode_ix % 192

            expected_row = 0
            expected_col = 0
            prob_values = volatile["attention_probs"][:-1].view(32, 192).data.cpu().numpy()
            z = 0.0
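            # Probability-weighted mean over a 1x3 window centered on the mode (same row,
            # columns -1 to +1); z accumulates the window's probability mass for renormalization.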
            for i in range(0, 1):
                for j in range(-1, 2):
                    row = mode_row + i
                    col = mode_col + j
                    if row < 0 or row >= 32 or col < 0 or col >= 192:
                        continue
                    expected_row = expected_row + row * prob_values[row, col]
                    expected_col = expected_col + col * prob_values[row, col]
                    z = z + prob_values[row, col]

            # print("Prob Values is %r, Mode Row is %r, Mode Col is %r, Expected Row is %r, Expected Col is %r, Z is %r"
            #       % (prob_values[mode_row, mode_col], mode_row, mode_col, expected_row, expected_col, z))
            inferred_ix = (expected_row * 192.0 + expected_col)/z
            if inferred_ix > 32 * 192:
                inferred_ix = 32 * 192

            print("Predicted Inferred ix is %r, Was %r " % (inferred_ix, mode_ix))

            # Renormalize the window-averaged coordinates by the window mass before returning.
            return inferred_ix, (expected_row / z + 0.5, expected_col / z)

        else:
            raise AssertionError("Not handled")


    def save_attention_prob(self, image, attention_prob, instruction, goal_prob=None):
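        """ Save a visualization of the predicted attention distribution to ./final_figures/.
        The panorama image and the (optional) gold goal distribution are also prepared, but
        only the resized attention map is plotted here; the side-by-side predicted/gold
        figure is kept commented out below. """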
        self.global_id += 1

        image_flipped = image.swapaxes(0, 1).swapaxes(1, 2)
        image_flipped = scipy.misc.imresize(image_flipped, (128, 128 * 6))
        attention_prob = attention_prob.cpu().data.numpy()
        resized_kernel = scipy.misc.imresize(attention_prob, (128*5, 128*5 * 6))
        if goal_prob is not None:
            goal_location = goal_prob.cpu().data.numpy()
            if np.sum(goal_location) > 0.01:
                for i in range(0, 32):
                    for j in range(0, 192):
                        if goal_location[i][j] < 0.01:
                            goal_location[i][j] = 0.0
                goal_location = scipy.misc.imresize(goal_location, (128, 128 * 6))
        else:
            goal_location = None

        plt.title(instruction)
        plt.imshow(resized_kernel, cmap='jet', alpha=0.5)
        plt.savefig("./final_figures/image_" + str(self.global_id) + ".png")
        plt.clf()

        # f, axarr = plt.subplots(1, 2)
        # if instruction is not None:
        #     f.suptitle(instruction)
        # axarr[0].set_title("Predicted Attention")
        # # axarr[0].imshow(image_flipped)
        # axarr[0].imshow(resized_kernel, cmap='jet', alpha=0.5)
        # axarr[1].set_title("Gold Attention (Goal)")
        # axarr[1].imshow(image_flipped)
        # if goal_location is not None:
        #     axarr[1].imshow(goal_location, cmap='jet', alpha=0.5)
        # plt.savefig("./attention_prob/image_" + str(self.global_id) + ".png")
        # plt.clf()

    def show_image(self, goal, predicted_goal, start_pos, instruction):
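        """ Plot the gold and predicted goals, along with the start position, on two 50x50
        top-down grids covering the world region [225, 275] x [225, 275], and save the
        figure to ./attention_prob/. """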
        self.global_id += 1

        # image_flipped = image.swapaxes(0, 1).swapaxes(1, 2)
        # image_flipped = scipy.misc.imresize(image_flipped, (128, 128 * 6))
        goal_map = np.zeros((50, 50))
        predicted_goal_map = np.zeros((50, 50))

        x_1, y_1 = goal
        x_2, y_2 = predicted_goal
        x_3, y_3, _ = start_pos

        x_1 = min(x_1, 274.99)
        y_1 = min(y_1, 274.99)
        x_2 = min(x_2, 274.99)
        y_2 = min(y_2, 274.99)
        x_3 = min(x_3, 274.99)
        y_3 = min(y_3, 274.99)

        print(" %r %r %r %r " % (x_1, y_1, x_2, y_2))
        assert 225.0 <= x_1 <= 275.0
        assert 225.0 <= x_2 <= 275.0
        assert 225.0 <= x_3 <= 275.0
        assert 225.0 <= y_1 <= 275.0
        assert 225.0 <= y_2 <= 275.0
        assert 225.0 <= y_3 <= 275.0

        i1, j1 = int((x_1 - 225.0)), int((y_1 - 225.0))
        i2, j2 = int((x_2 - 225.0)), int((y_2 - 225.0))
        i3, j3 = int((x_3 - 225.0)), int((y_3 - 225.0))

        goal_map[i1, j1] = 1.0
        goal_map[i3, j3] = 0.75
        predicted_goal_map[i2, j2] = 1.0
        predicted_goal_map[i3, j3] = 0.75

        f, axarr = plt.subplots(1, 2)
        if instruction is not None:
            f.suptitle(instruction)
        axarr[0].set_title("Predicted Goal")
        # axarr[0].imshow(image_flipped)
        axarr[0].imshow(predicted_goal_map, cmap='jet', alpha=0.5)
        axarr[1].set_title("Gold Goal")
        # axarr[1].imshow(image_flipped)
        axarr[1].imshow(goal_map, cmap='jet', alpha=0.5)
        plt.savefig("./attention_prob/image_" + str(self.global_id) + "_maps.png")
        plt.clf()

    def interactive_shell(self, train_dataset, train_images):
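        """ Interactive debugging loop: sample a panorama from the training data, read an
        instruction from stdin, overlay the model's attention probabilities on the image, and
        optionally save the figure (s), keep working on the same environment (k), or both (sk).
        Entering q or quit at the instruction prompt exits the loop. """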

        traj_len = len(train_dataset)
        keep = False
        image_id = 1
        while True:

            # Sample a random dataset
            if not keep:
                ix = random.randint(0, traj_len - 1)
            data_point = train_dataset[ix]
            image = train_images[ix][0]

            # Show the image in pyplot
            plt.imshow(image.swapaxes(0, 1).swapaxes(1, 2))
            plt.ion()
            plt.show()

            # Get the instruction
            print("Enter the instruction below (q or quit to quit)\n")
            print("Sample instruction is ", instruction_to_string(data_point.instruction, self.config))
            while True:
                instruction = input()
                if instruction == "q" or instruction == "quit":
                    break
                elif len(instruction) == 0:
                    print("Enter a non-empty instruction (q or quit to quit)")
                else:
                    break

            if instruction == "q" or instruction == "quit":
                break

            instruction_id = self.convert_to_id(instruction)
            state = AgentObservedState(instruction=instruction_id,
                                       config=self.config,
                                       constants=self.constants,
                                       start_image=image,
                                       previous_action=None,
                                       pose=None,
                                       position_orientation=None,
                                       data_point=data_point)

            # Show the attention mask
            _, _, _, volatile = self.model.get_attention_prob(state, model_state=None)

            attention_prob = volatile["attention_probs"][:-1].view(self.final_height, self.final_width)
            attention_prob = attention_prob.cpu().data.numpy()
            resized_kernel = scipy.misc.imresize(attention_prob,
                                                 (self.config["image_height"], self.config["image_width"]))
            plt.clf()
            plt.title(instruction)
            plt.imshow(image.swapaxes(0, 1).swapaxes(1, 2))
            plt.imshow(resized_kernel, cmap="jet", alpha=0.5)

            print("Enter s to save, k to keep working on this environment, sk to do both. Other key to simply continue")
            key_ = input()
            if key_ == "s":
                plt.savefig("interactive_image_" + str(image_id) + ".png")
                image_id += 1

            if key_ == "k":
                keep = True
            else:
                keep = False

            if key_ == "sk":
                plt.savefig("image_" + str(image_id) + ".png")
                image_id += 1
                keep = True

            plt.clf()

    def test(self, tune_dataset, tune_image, tune_goal_location, tensorboard):
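        """ Evaluate goal prediction on the tuning set. Reports mean loss, gold probability,
        exact-match accuracy, L1 pixel distance, epsilon accuracy (prediction within 10% of
        the larger map dimension), mean real-world distance for examples with a visible goal,
        and task completion accuracy (real-world error of at most 5 units). Returns the mean
        real-world distance. """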

        total_validation_loss = 0
        total_validation_prob = 0
        total_validation_exact_accuracy = 0
        total_goal_distance = 0
        num_items = 0

        # Next metric measures when the goal is visible and the prediction is within a 10% radius
        total_epsilon_accuracy = 0
        num_visible_items = 0

        # Next metric measures distance in real world and only when goal is visible
        total_real_world_distance = 0

        correct = 0
        count_correct = 0

        for data_point_ix, data_point in enumerate(tune_dataset):
            tune_image_example = tune_image[data_point_ix]
            goal_location = tune_goal_location[data_point_ix]
            image = tune_image_example[0]

            model_state = None
            state = AgentObservedState(instruction=data_point.instruction,
                                       config=self.config,
                                       constants=self.constants,
                                       start_image=image,
                                       previous_action=None,
                                       pose=None,
                                       position_orientation=data_point.get_start_pos(),
                                       data_point=data_point)

            num_items_ = 0
            sum_loss = 0
            sum_prob = 0
            sum_acc = 0
            sum_dist = 0
            sum_real_world_distance = 0

            goal = goal_location[0]
            state.goal = goal
            volatile = self.model.get_attention_prob(state, model_state)
            row, col, _, _ = goal

            if not self.ignore_none or row is not None:
                if row is None or col is None:
                    gold_ix = self.final_height * self.final_width
                else:
                    gold_ix = row * self.final_width + col
                loss, prob, meta = GoalPrediction.get_loss_and_prob(
                    volatile, goal, self.final_height, self.final_width)
                num_items_ += 1
                sum_loss = sum_loss + float(loss.data.cpu().numpy()[0])
                sum_prob = sum_prob + float(prob.data.cpu().numpy()[0])

                inferred_ix, row_col = self.get_inferred_value(volatile)
                # Center pixel prediction
                # inferred_ix, row_col = 20 * 192 + 32 * 3 + 16, None

                if gold_ix == inferred_ix:
                    sum_acc = sum_acc + 1.0
                if row is not None and col is not None:
                    sum_dist = sum_dist + abs(row - int(inferred_ix / self.final_width)) \
                               + abs(col - int(inferred_ix % self.final_width))
                    num_visible_items += 1
                    if self.is_close_enough(inferred_ix, row, col):
                        total_epsilon_accuracy += 1
                    predicted_goal, real_world_distance = self.compute_distance_in_real_world(inferred_ix, row_col, data_point)
                    sum_real_world_distance += real_world_distance

                    count_correct += 1.0
                    if real_world_distance <= 5.0:
                        correct += 1.0

                    # # Save the map
                    # instruction_string = instruction_to_string(data_point.instruction, self.config)
                    # goal_x, goal_y = data_point.get_destination_list()[-1]
                    # goal_x, goal_y = round(goal_x, 2), round(goal_y, 2)
                    # predicted_goal_x, predicted_goal_y = predicted_goal
                    # predicted_goal_x, predicted_goal_y = round(predicted_goal_x, 2), round(predicted_goal_y, 2)
                    # instruction_string = instruction_string + \
                    #                      "\n (Error: " + str(round(sum_real_world_distance, 2)) + ")" + \
                    #                      "\n %r %r %r %r \n" % (goal_x, goal_y, predicted_goal_x, predicted_goal_y)
                    # self.show_image(data_point.get_destination_list()[-1], predicted_goal, data_point.get_start_pos(),
                    #                 instruction_string)
                    #
                    # # Save the generated image
                    # goal_prob = GoalPrediction.generate_gold_prob(goal, 32, 192)
                    # predicted_goal = (int(inferred_ix/192), inferred_ix % 192, int(inferred_ix/192), inferred_ix % 192)
                    # predicted_goal_prob = GoalPrediction.generate_gold_prob(predicted_goal, 32, 192)
                    # self.save_attention_prob(image, volatile["attention_probs"][:-1].view(32, 192),
                    #                          instruction_string, goal_prob[:-1].view(32, 192))
                    # self.save_attention_prob(image, predicted_goal_prob[:-1].view(32, 192),
                    #                          instruction_string, goal_prob[:-1].view(32, 192))

            total_validation_loss += sum_loss
            total_validation_prob += sum_prob
            total_goal_distance += sum_dist
            total_validation_exact_accuracy += sum_acc
            total_real_world_distance += sum_real_world_distance
            num_items += num_items_

        mean_total_goal_distance = total_goal_distance / float(max(num_items, 1))
        mean_total_validation_loss = total_validation_loss / float(max(num_items, 1))
        mean_total_validation_prob = total_validation_prob / float(max(num_items, 1))
        mean_total_validation_accuracy = (total_validation_exact_accuracy * 100.0) / float(max(num_items, 1))
        mean_total_epsilon_accuracy = (total_epsilon_accuracy * 100.0) / float(max(num_visible_items, 1))
        mean_real_world_distance = total_real_world_distance / float(max(num_visible_items, 1))

        logging.info("Mean Test result: L1 Distance is %r, Loss %r, Prob %r, Acc is %r, Epsilon Accuracy is %r"
                     % (mean_total_goal_distance, mean_total_validation_loss, mean_total_validation_prob,
                        mean_total_validation_accuracy, mean_total_epsilon_accuracy))
        logging.info("Num visible items %r, Num Exact Match items is %r, Num epsilon match %r, Num Items is %r "
                     % (num_visible_items, total_validation_exact_accuracy, total_epsilon_accuracy, num_items))
        logging.info("Num visible items %r, Mean Real World Distance %r "
                     % (num_visible_items, mean_real_world_distance))
        logging.info("Num counts %r, Task Completion Accuracy %r "
                     % (count_correct, (correct * 100.0)/float(max(1, count_correct))))

        return mean_real_world_distance

    def do_train(self, train_dataset, train_images, train_goal_location,
                 tune_dataset, tune_images, tune_goal_location, experiment_name, save_best_model=False):
        """ Perform training """

        dataset_size = len(train_dataset)
        tensorboard = self.tensorboard

        # Test on tuning data with initialized model
        mean_real_world_distance = self.test(tune_dataset, tune_images, tune_goal_location, tensorboard=tensorboard)
        best_real_world_distance = mean_real_world_distance

        for epoch in range(1, self.max_epoch + 1):

            logging.info("Starting epoch %d", epoch)

            batch_replay_items = []
            best_real_world_distance = min(best_real_world_distance, mean_real_world_distance)

            for data_point_ix, data_point in enumerate(train_dataset):

                if (data_point_ix + 1) % 100 == 0:
                    logging.info("Done %d out of %d", data_point_ix, dataset_size)

                train_images_example = train_images[data_point_ix]
                goal_location = train_goal_location[data_point_ix]
                image = train_images_example[0]

                model_state = None
                state = AgentObservedState(instruction=data_point.instruction,
                                           config=self.config,
                                           constants=self.constants,
                                           start_image=image,
                                           previous_action=None,
                                           pose=None,
                                           position_orientation=data_point.get_start_pos(),
                                           data_point=data_point)

                # Generate attention probabilities
                volatile = self.model.get_attention_prob(state, model_state)
                goal = goal_location[0]

                # Store it in the replay memory list
                if not self.ignore_none or goal[0] is not None:
                    replay_item = ReplayMemoryItem(state, None, 0, volatile=volatile, goal=goal)
                    batch_replay_items.append(replay_item)

                # Perform update
                if len(batch_replay_items) > 0:
                    loss_val = self.do_update(batch_replay_items)
                    batch_replay_items = []
                    if tensorboard is not None:
                        tensorboard.log_scalar("Loss", loss_val)
                        if self.goal_prediction_loss is not None:
                            goal_prediction_loss = float(self.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss", goal_prediction_loss)
                        if self.goal_prob is not None:
                            goal_prob = float(self.goal_prob.data[0])
                            tensorboard.log_scalar("goal_prob", goal_prob)
                        if self.object_detection_loss is not None:
                            object_detection_loss = float(self.object_detection_loss.data[0])
                            tensorboard.log_scalar("object_detection_loss", object_detection_loss)
                        if self.cross_entropy_loss is not None:
                            cross_entropy_loss = float(self.cross_entropy_loss.data[0])
                            tensorboard.log_scalar("Cross_entropy_loss", cross_entropy_loss)
                        if self.dist_loss is not None:
                            dist_loss = float(self.dist_loss.data[0])
                            tensorboard.log_scalar("Dist_loss", dist_loss)

            mean_real_world_distance = self.test(tune_dataset, tune_images, tune_goal_location, tensorboard=tensorboard)

            # Save the model
            if save_best_model:
                if mean_real_world_distance < best_real_world_distance:
                    self.model.save_model(experiment_name + "/goal_prediction_single_supervised_epoch_" + str(epoch))
            else:
                self.model.save_model(experiment_name + "/goal_prediction_single_supervised_epoch_" + str(epoch))