class ModelPolicyNetworkResnetWithStop(AbstractModel):
    """Feed-forward policy: ResNet image encoder + instruction encoder +
    previous-action embedding, fused by ``MultimodalSimpleWithStopModule``.

    Each sub-module is checkpointed to its own ``*_state.bin`` file inside a
    checkpoint directory.
    """

    def __init__(self, config, constants):
        AbstractModel.__init__(self, config, constants)
        num_actions = config["num_actions"]
        # One past the last real action id; stands in for "no previous action".
        self.none_action = num_actions

        image_dim = constants["image_emb_dim"]
        self.image_module = ImageResnetModule(
            image_emb_size=image_dim,
            # 3 channels per image times the maximum number of images
            # (presumably stacked along the channel axis).
            input_num_channels=3 * constants["max_num_images"],
            image_height=config["image_height"],
            image_width=config["image_width"])

        lstm_dim = constants["lstm_emb_dim"]
        text_module_cls = (TextPointerModule if config["use_pointer_model"]
                           else TextSimpleModule)
        self.text_module = text_module_cls(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=lstm_dim,
            vocab_size=config["vocab_size"])

        action_dim = constants["action_emb_dim"]
        self.action_module = ActionSimpleModule(
            num_actions=num_actions, action_emb_size=action_dim)

        # The fused input is the concatenation of all three embeddings.
        self.final_module = MultimodalSimpleWithStopModule(
            image_module=self.image_module,
            text_module=self.text_module,
            action_module=self.action_module,
            total_emb_size=image_dim + lstm_dim + action_dim,
            num_actions=num_actions)

        if torch.cuda.is_available():
            for module in (self.image_module, self.text_module,
                           self.action_module, self.final_module):
                module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Compute action probabilities for a batch of observed states.

        The batch is reordered by instruction length, longest first, before
        being encoded. The returned rows follow that sorted order, not the
        caller's order.
        """
        assert all(isinstance(state, AgentObservedState)
                   for state in agent_observed_state_list)

        ordered = sorted(agent_observed_state_list,
                         key=lambda state: len(state.get_instruction()),
                         reverse=True)

        image_batch = cuda_var(torch.from_numpy(
            np.array([state.get_image() for state in ordered])).float())

        instructions_batch = (
            [state.get_instruction() for state in ordered],
            [state.get_read_pointers() for state in ordered])

        previous = []
        for state in ordered:
            action = state.get_previous_action()
            previous.append(self.none_action if action is None else action)
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(previous)))

        return self.final_module(image_batch, instructions_batch,
                                 prev_actions_batch, mode)

    def load_saved_model(self, load_dir):
        """Restore every sub-module from its ``*_state.bin`` file in *load_dir*.

        When CUDA is unavailable, stored tensors are remapped onto CPU storage.
        """
        on_gpu = torch.cuda.is_available()
        checkpoints = (
            (self.image_module, "image_module_state.bin"),
            (self.text_module, "text_module_state.bin"),
            (self.action_module, "action_module_state.bin"),
            (self.final_module, "final_module_state.bin"),
        )
        for module, file_name in checkpoints:
            path = os.path.join(load_dir, file_name)
            if on_gpu:
                state = torch.load(path)
            else:
                state = torch.load(path, map_location=lambda storage, loc: storage)
            module.load_state_dict(state)

    def save_model(self, save_dir):
        """Write each sub-module's state dict to *save_dir*, creating it if needed."""
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        checkpoints = (
            (self.image_module, "image_module_state.bin"),
            (self.text_module, "text_module_state.bin"),
            (self.action_module, "action_module_state.bin"),
            (self.final_module, "final_module_state.bin"),
        )
        for module, file_name in checkpoints:
            torch.save(module.state_dict(), os.path.join(save_dir, file_name))

    def get_parameters(self):
        """Return a flat list of the parameters of all four sub-modules."""
        params = []
        for module in (self.image_module, self.text_module,
                       self.action_module, self.final_module):
            params.extend(module.parameters())
        return params
class IncrementalModelEmnlp(AbstractIncrementalModel):
    """Step-by-step policy for the EMNLP block task.

    The model combines four inputs:

    * a CNN over the most recent images, stacked along the channel axis;
    * an instruction encoder;
    * embeddings of the previous block id and the previous direction id.

    ``IncrementalMultimodalEmnlp`` fuses them and returns action
    probabilities together with the next recurrent model state.
    """

    # Each image contributes this many channels to the CNN input.
    _CHANNELS_PER_IMAGE = 3
    # Number of most-recent images stacked as the CNN input.
    _NUM_HISTORY_IMAGES = 5
    # A flat previous-action id is split as (id // 4, id % 4) into
    # (block id, direction id).
    _NUM_DIRECTIONS = 4

    def __init__(self, config, constants):
        AbstractIncrementalModel.__init__(self, config, constants)
        # One past the last real action id; stands in for "no previous action".
        self.none_action = config["num_actions"]

        self.config = config
        self.constants = constants

        # CNN over the stacked image history.
        self.image_module = ImageCnnEmnlp(
            image_emb_size=config["image_emb_dim"],
            input_num_channels=self._CHANNELS_PER_IMAGE * self._NUM_HISTORY_IMAGES,
            image_height=config["image_height"],
            image_width=config["image_width"])

        # emb_dim is the word-embedding size; hidden_dim is the encoder's
        # output size.
        self.text_module = TextSimpleModule(emb_dim=config["word_emb_dim"],
                                            hidden_dim=config["lstm_emb_dim"],
                                            vocab_size=config["vocab_size"])

        self.previous_action_module = ActionSimpleModule(
            num_actions=config["no_actions"],
            action_emb_size=config["previous_action_embedding_dim"])

        self.previous_block_module = ActionSimpleModule(
            num_actions=config["no_blocks"],
            action_emb_size=config["previous_block_embedding_dim"])

        self.final_module = IncrementalMultimodalEmnlp(
            image_module=self.image_module,
            text_module=self.text_module,
            previous_action_module=self.previous_action_module,
            previous_block_module=self.previous_block_module,
            input_embedding_size=config["lstm_emb_dim"] +
            config["image_emb_dim"] + config["previous_action_embedding_dim"] +
            config["previous_block_embedding_dim"],
            output_hidden_size=config["h1_hidden_dim"],
            blocks_hidden_size=config["no_blocks"],
            directions_hidden_size=config["no_actions"],
            max_episode_length=(constants["horizon"] + 5))

        if torch.cuda.is_available():
            self.image_module.cuda()
            self.text_module.cuda()
            self.previous_action_module.cuda()
            self.previous_block_module.cuda()
            self.final_module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Batched inference is not supported by this model."""
        raise NotImplementedError()

    def get_probs(self,
                  agent_observed_state,
                  model_state,
                  mode=None,
                  volatile=False):
        """Run one step of the policy for a single observed state.

        Args:
            agent_observed_state: the current ``AgentObservedState``.
            model_state: recurrent state from the previous step. It is
                passed unchanged to ``final_module``.
            mode: not used by this method; kept for interface compatibility.
            volatile: forwarded to ``cuda_var`` for every input tensor.

        Returns:
            ``(probs_batch, new_model_state, None, None)``. The last two
            slots are placeholders kept for interface compatibility.
        """
        assert isinstance(agent_observed_state, AgentObservedState)

        # NOTE(review): assumes get_image() already pads the history to at
        # least _NUM_HISTORY_IMAGES images; the original author flagged this
        # as unverified.
        images = agent_observed_state.get_image()[-self._NUM_HISTORY_IMAGES:]
        image_batch = cuda_var(
            torch.from_numpy(np.array(images)).float(), volatile)

        # Batch size is 1. Fold the image history into the channel axis,
        # using the same spatial size the CNN was built with (previously
        # hard-coded as 128x128).
        image_batch = image_batch.view(
            1, self._CHANNELS_PER_IMAGE * self._NUM_HISTORY_IMAGES,
            self.config["image_height"], self.config["image_width"])

        # A list holding a single instruction; the second slot is False
        # (no read pointers are used by this model).
        instructions_batch = ([agent_observed_state.get_instruction()], False)

        prev_action = agent_observed_state.get_previous_action()
        if prev_action is None:
            prev_action = self.none_action

        # NOTE(review): 81 is presumably the "no previous action" id
        # (self.none_action when num_actions == 81). Confirm before
        # replacing the literal.
        if prev_action == 81:
            # One past the last real direction id.
            previous_direction_id = [self._NUM_DIRECTIONS]
        else:
            previous_direction_id = [prev_action % self._NUM_DIRECTIONS]
        previous_block_id = [prev_action // self._NUM_DIRECTIONS]

        prev_block_id_batch = cuda_var(
            torch.from_numpy(np.array(previous_block_id)), volatile)
        prev_direction_id_batch = cuda_var(
            torch.from_numpy(np.array(previous_direction_id)), volatile)

        probs_batch, new_model_state = self.final_module(
            image_batch, instructions_batch, prev_block_id_batch,
            prev_direction_id_batch, model_state)

        return probs_batch, new_model_state, None, None

    def init_weights(self):
        """Call ``init_weights()`` on every sub-module."""
        self.text_module.init_weights()
        self.image_module.init_weights()
        self.previous_action_module.init_weights()
        self.previous_block_module.init_weights()
        self.final_module.init_weights()

    def share_memory(self):
        """Call ``share_memory()`` on every sub-module."""
        self.image_module.share_memory()
        self.text_module.share_memory()
        self.previous_action_module.share_memory()
        self.previous_block_module.share_memory()
        self.final_module.share_memory()

    def get_state_dict(self):
        """Return a dict mapping sub-module name to that module's state dict."""
        return {
            "image_module": self.image_module.state_dict(),
            "text_module": self.text_module.state_dict(),
            "previous_action_module": self.previous_action_module.state_dict(),
            "previous_block_module": self.previous_block_module.state_dict(),
            "final_module": self.final_module.state_dict(),
        }

    def load_from_state_dict(self, nested_state_dict):
        """Inverse of ``get_state_dict``."""
        self.image_module.load_state_dict(nested_state_dict["image_module"])
        self.text_module.load_state_dict(nested_state_dict["text_module"])
        self.previous_action_module.load_state_dict(
            nested_state_dict["previous_action_module"])
        self.previous_block_module.load_state_dict(
            nested_state_dict["previous_block_module"])
        self.final_module.load_state_dict(nested_state_dict["final_module"])

    @staticmethod
    def _torch_load(path):
        """``torch.load`` that remaps storages to CPU when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, loc: storage)

    def load_resnet_model(self, load_dir):
        """Load only the image CNN weights from *load_dir*."""
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._torch_load(image_module_path))

    def load_lstm_model(self, load_dir):
        """Load only the instruction-encoder weights from *load_dir*."""
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(self._torch_load(text_module_path))

    def load_saved_model(self, load_dir):
        """Load every sub-module from the files written by ``save_model``."""
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._torch_load(image_module_path))

        previous_action_module_path = os.path.join(
            load_dir, "previous_action_module_state.bin")
        self.previous_action_module.load_state_dict(
            self._torch_load(previous_action_module_path))

        previous_block_module_path = os.path.join(
            load_dir, "previous_block_module_state.bin")
        self.previous_block_module.load_state_dict(
            self._torch_load(previous_block_module_path))

        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(self._torch_load(text_module_path))

        final_module_path = os.path.join(load_dir, "final_module_state.bin")
        self.final_module.load_state_dict(self._torch_load(final_module_path))

    def save_model(self, save_dir):
        """Write each sub-module's state dict to *save_dir*, creating it if needed."""
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        image_module_path = os.path.join(save_dir, "image_module_state.bin")
        torch.save(self.image_module.state_dict(), image_module_path)

        previous_action_module_path = os.path.join(
            save_dir, "previous_action_module_state.bin")
        torch.save(self.previous_action_module.state_dict(),
                   previous_action_module_path)

        previous_block_module_path = os.path.join(
            save_dir, "previous_block_module_state.bin")
        torch.save(self.previous_block_module.state_dict(),
                   previous_block_module_path)

        text_module_path = os.path.join(save_dir, "text_module_state.bin")
        torch.save(self.text_module.state_dict(), text_module_path)

        final_module_path = os.path.join(save_dir, "final_module_state.bin")
        torch.save(self.final_module.state_dict(), final_module_path)

    def get_parameters(self):
        """Return the parameters of ``final_module`` only.

        NOTE(review): the other sub-modules are passed to final_module's
        constructor. Presumably they are registered as its children, so
        their parameters are included here — confirm.
        """
        return list(self.final_module.parameters())

    def get_named_parameters(self):
        """Return ``final_module.named_parameters()`` as a list (see get_parameters)."""
        return list(self.final_module.named_parameters())
# Example no. 3 (0)
class ModelPolicyNetworkSymbolic(AbstractModel):
    """Policy network whose "image" input is symbolic rather than pixels.

    For each state, ``get_visible_landmark_r_theta`` turns the agent pose and
    landmark positions into a symbolic image. ``SymbolicImageModule`` (built
    from radius, angle and landmark sub-modules) encodes it.
    ``MultimodalSimpleModule`` then combines that encoding with the
    instruction encoding and the previous-action embedding.
    """

    def __init__(self, config, constants):
        AbstractModel.__init__(self, config, constants)
        # One past the last real action id; stands in for "no previous action".
        self.none_action = config["num_actions"]
        landmark_names = get_all_landmark_names()
        self.radius_module = RadiusModule(15)
        self.angle_module = AngleModule(48)
        self.landmark_module = LandmarkModule(63)
        self.image_module = SymbolicImageModule(
            landmark_names=landmark_names,
            radius_module=self.radius_module,
            angle_module=self.angle_module,
            landmark_module=self.landmark_module)
        if config["use_pointer_model"]:
            self.text_module = TextPointerModule(
                emb_dim=constants["word_emb_dim"],
                hidden_dim=constants["lstm_emb_dim"],
                vocab_size=config["vocab_size"])
        else:
            self.text_module = TextSimpleModule(
                emb_dim=constants["word_emb_dim"],
                hidden_dim=constants["lstm_emb_dim"],
                vocab_size=config["vocab_size"])
        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])
        # NOTE(review): 32 * 3 * 63 is presumably the size of the symbolic
        # image embedding. 63 is the same literal passed to LandmarkModule
        # above — confirm the two must stay in sync.
        total_emb_size = (32 * 3 * 63
                          + constants["lstm_emb_dim"]
                          + constants["action_emb_dim"])
        final_module = MultimodalSimpleModule(
            image_module=self.image_module,
            text_module=self.text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])
        self.final_module = final_module
        if torch.cuda.is_available():
            self.image_module.cuda()
            self.text_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()
            self.radius_module.cuda()
            self.angle_module.cuda()
            self.landmark_module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Compute action probabilities for a batch of observed states.

        The batch is reordered by instruction length, longest first. The
        returned rows follow that sorted order, not the caller's order.
        """
        for aos in agent_observed_state_list:
            assert isinstance(aos, AgentObservedState)

        agent_observed_state_list = sorted(
            agent_observed_state_list,
            key=lambda aos_: len(aos_.get_instruction()),
            reverse=True
        )

        # Build one symbolic image per state from the agent pose and the
        # landmark positions.
        symbolic_image_list = []
        for aos in agent_observed_state_list:
            x_pos, z_pos, y_angle = aos.get_position_orientation()
            landmark_pos_dict = aos.get_landmark_pos_dict()
            symbolic_image = get_visible_landmark_r_theta(
                x_pos, z_pos, y_angle, landmark_pos_dict)
            symbolic_image_list.append(symbolic_image)
        # Passed to final_module as a plain list, not a tensor.
        image_batch = symbolic_image_list

        instructions = [aos.get_instruction()
                        for aos in agent_observed_state_list]
        read_pointers = [aos.get_read_pointers()
                         for aos in agent_observed_state_list]
        instructions_batch = (instructions, read_pointers)

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        probs_batch = self.final_module(image_batch, instructions_batch,
                                        prev_actions_batch, mode)
        return probs_batch

    def load_saved_model(self, load_dir):
        """Restore every sub-module from the files written by ``save_model``.

        When CUDA is unavailable, stored tensors are remapped onto CPU storage.
        """
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            torch_load = lambda f_: torch.load(f_, map_location=lambda s_, l_: s_)
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(torch_load(image_module_path))
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(torch_load(text_module_path))
        action_module_path = os.path.join(load_dir, "action_module_state.bin")
        self.action_module.load_state_dict(torch_load(action_module_path))
        final_module_path = os.path.join(load_dir, "final_module_state.bin")
        self.final_module.load_state_dict(torch_load(final_module_path))
        radius_module_path = os.path.join(load_dir, "radius_module_state.bin")
        self.radius_module.load_state_dict(torch_load(radius_module_path))
        angle_module_path = os.path.join(load_dir, "angle_module_state.bin")
        self.angle_module.load_state_dict(torch_load(angle_module_path))
        landmark_module_path = os.path.join(load_dir, "landmark_module_state.bin")
        self.landmark_module.load_state_dict(torch_load(landmark_module_path))

    def save_model(self, save_dir):
        """Write every sub-module's state dict to *save_dir*, creating it if needed.

        Writes the same set of files that ``load_saved_model`` reads,
        including the radius/angle/landmark modules.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # save state file for image nn
        image_module_path = os.path.join(save_dir, "image_module_state.bin")
        torch.save(self.image_module.state_dict(), image_module_path)
        # save state file for text nn
        text_module_path = os.path.join(save_dir, "text_module_state.bin")
        torch.save(self.text_module.state_dict(), text_module_path)
        # save state file for action emb
        action_module_path = os.path.join(save_dir, "action_module_state.bin")
        torch.save(self.action_module.state_dict(), action_module_path)
        # save state file for final nn
        final_module_path = os.path.join(save_dir, "final_module_state.bin")
        torch.save(self.final_module.state_dict(), final_module_path)
        # load_saved_model reads these three files as well, so without them
        # a checkpoint written here could not be loaded back.
        radius_module_path = os.path.join(save_dir, "radius_module_state.bin")
        torch.save(self.radius_module.state_dict(), radius_module_path)
        angle_module_path = os.path.join(save_dir, "angle_module_state.bin")
        torch.save(self.angle_module.state_dict(), angle_module_path)
        landmark_module_path = os.path.join(save_dir, "landmark_module_state.bin")
        torch.save(self.landmark_module.state_dict(), landmark_module_path)

    def get_parameters(self):
        """Return a flat list of the image, text, action and final module parameters.

        NOTE(review): radius/angle/landmark modules are not listed
        separately. Presumably they are registered inside image_module —
        confirm.
        """
        parameters = list(self.image_module.parameters())
        parameters += list(self.text_module.parameters())
        parameters += list(self.action_module.parameters())
        parameters += list(self.final_module.parameters())
        return parameters
# Example no. 4 (0)
class IncrementalModelRecurrentPolicyNetworkResnet(AbstractIncrementalModel):
    def __init__(self, config, constants):
        """Build the recurrent ResNet policy and any auxiliary heads enabled in *config*.

        Optional heads (action prediction, temporal autoencoding, object
        detection, symbolic language prediction, goal prediction) are set to
        ``None`` when their ``do_*`` flag is false.
        """
        AbstractIncrementalModel.__init__(self, config, constants)
        num_actions = config["num_actions"]
        image_dim = constants["image_emb_dim"]
        lstm_dim = constants["lstm_emb_dim"]
        action_dim = constants["action_emb_dim"]

        # One past the last real action id; stands in for "no previous action".
        self.none_action = num_actions
        self.num_cameras = 1

        # Keep the module construction order unchanged: constructors may
        # draw from the RNG when initialising weights.
        self.image_module = ImageResnetModule(
            image_emb_size=image_dim,
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        self.image_recurrence_module = IncrementalRecurrenceSimpleModule(
            input_emb_dim=image_dim * self.num_cameras + action_dim,
            output_emb_dim=image_dim)

        use_pointer = config["use_pointer_model"]
        text_module_cls = TextPointerModule if use_pointer else TextBiLSTMModule
        self.text_module = text_module_cls(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=lstm_dim,
            vocab_size=config["vocab_size"])
        self.action_module = ActionSimpleModule(
            num_actions=num_actions, action_emb_size=action_dim)

        if use_pointer:
            total_emb_size = image_dim + 4 * lstm_dim + action_dim
        else:
            total_emb_size = ((self.num_cameras + 1) * image_dim
                              + 2 * lstm_dim + action_dim)

        camera_dim = self.num_cameras * image_dim
        self.action_prediction_module = (
            ActionPredictionModule(2 * camera_dim, image_dim, num_actions)
            if config["do_action_prediction"] else None)

        self.temporal_autoencoder_module = (
            TemporalAutoencoderModule(self.action_module, camera_dim,
                                      action_dim, image_dim)
            if config["do_temporal_autoencoding"] else None)

        if config["do_object_detection"]:
            self.landmark_names = get_all_landmark_names()
            self.object_detection_module = ObjectDetectionModule(
                image_module=self.image_module, image_emb_size=camera_dim, num_objects=67)
        else:
            self.object_detection_module = None

        self.symbolic_language_prediction_module = (
            SymbolicLanguagePredictionModule(total_emb_size=2 * lstm_dim)
            if config["do_symbolic_language_prediction"] else None)

        self.goal_prediction_module = (
            GoalPredictionModule(total_emb_size=32)
            if config["do_goal_prediction"] else None)

        self.final_module = TmpIncrementalMultimodalDenseValtsRecurrentSimpleModule(
            image_module=self.image_module,
            image_recurrence_module=self.image_recurrence_module,
            text_module=self.text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=num_actions)

        if torch.cuda.is_available():
            every_module = (
                self.image_module,
                self.image_recurrence_module,
                self.text_module,
                self.action_module,
                self.final_module,
                self.action_prediction_module,
                self.temporal_autoencoder_module,
                self.object_detection_module,
                self.symbolic_language_prediction_module,
                self.goal_prediction_module,
            )
            for module in every_module:
                if module is not None:
                    module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Batched forward pass -- currently disabled.

        Always raises ``AssertionError("Buggy")``. Everything after the
        ``raise`` is unreachable and kept only for reference.

        NOTE(review): the disabled code unpacks two values from
        ``self.final_module``, while ``get_probs`` unpacks four from the
        same module. That mismatch may be why this path was marked buggy;
        confirm before re-enabling.
        """
        raise AssertionError("Buggy")
        # --- unreachable below this line ---
        for aos in agent_observed_state_list:
            assert isinstance(aos, AgentObservedState)
        # print "batch size:", len(agent_observed_state_list)

        # sort list by instruction length
        agent_observed_state_list = sorted(
            agent_observed_state_list,
            key=lambda aos_: len(aos_.get_instruction()),
            reverse=True
        )

        image_seq_lens = [aos.get_num_images()
                          for aos in agent_observed_state_list]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        max_len = max(image_seq_lens)
        image_seqs = [aos.get_image()[:max_len]
                      for aos in agent_observed_state_list]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float())

        instructions = [aos.get_instruction()
                        for aos in agent_observed_state_list]
        read_pointers = [aos.get_read_pointers()
                         for aos in agent_observed_state_list]
        instructions_batch = (instructions, read_pointers)

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        probs_batch, _ = self.final_module(image_batch, image_seq_lens_batch,
                                           instructions_batch, prev_actions_batch,
                                           mode, model_state=None)
        return probs_batch

    def get_probs(self, agent_observed_state, model_state, mode=None, volatile=False):
        """Run one step of the policy on a single observed state.

        Only the latest image is fed in, as a sequence of length 1; the
        recurrence carries the history through *model_state*.

        Returns:
            ``(probs_batch, new_model_state, image_emb_seq, state_feature)``
            exactly as produced by ``final_module``.
        """
        assert isinstance(agent_observed_state, AgentObservedState)
        state = agent_observed_state

        seq_lens_batch = cuda_tensor(torch.from_numpy(np.array([1])))

        # Shape (batch=1, seq=1, ...): a single state holding one image.
        last_image = state.get_last_image()
        image_batch = cuda_var(
            torch.from_numpy(np.array([[last_image]])).float(), volatile)

        instructions_batch = ([state.get_instruction()],
                              [state.get_read_pointers()])

        previous_action = state.get_previous_action()
        if previous_action is None:
            previous_action = self.none_action
        prev_actions_batch = cuda_var(
            torch.from_numpy(np.array([previous_action])), volatile)

        outputs = self.final_module(image_batch, seq_lens_batch, instructions_batch,
                                    prev_actions_batch, mode, model_state)
        probs_batch, new_model_state, image_emb_seq, state_feature = outputs
        return probs_batch, new_model_state, image_emb_seq, state_feature

    def action_prediction_log_prob(self, batch_input):
        """Return the action-prediction head's output for *batch_input*.

        Raises:
            AssertionError: if the head was not created
                (``config["do_action_prediction"]`` was false). The check is
                an explicit raise so it still fires under ``python -O``.
        """
        if self.action_prediction_module is None:
            raise AssertionError("Action prediction module not created. Check config.")
        return self.action_prediction_module(batch_input)

    def predict_action_result(self, batch_image_feature, action_batch):
        """Return the temporal-autoencoder head's output for (image features, actions).

        Raises:
            AssertionError: if the head was not created
                (``config["do_temporal_autoencoding"]`` was false). The check
                is an explicit raise so it still fires under ``python -O``.
        """
        if self.temporal_autoencoder_module is None:
            raise AssertionError("Temporal action module not created. Check config.")
        return self.temporal_autoencoder_module(batch_image_feature, action_batch)

    def predict_goal_result(self, batch_state_feature):
        """Return the goal-prediction head's output for *batch_state_feature*.

        Raises:
            AssertionError: if the head was not created
                (``config["do_goal_prediction"]`` was false). The check is an
                explicit raise so it still fires under ``python -O``.
        """
        if self.goal_prediction_module is None:
            raise AssertionError("Goal Prediction module not created. Check config.")
        return self.goal_prediction_module(batch_state_feature)

    def get_probs_and_visible_objects(self, agent_observed_state_list, batch_image_feature):
        """Run the object-detection head and compute which landmarks are visible.

        Args:
            agent_observed_state_list: observed states, one per batch row.
            batch_image_feature: image features fed to the detection module.

        Returns:
            ``(landmark_log_prob, distance_log_prob, theta_log_prob,
            landmarks_visible)``. The first three come from
            ``object_detection_module``. ``landmarks_visible`` holds one
            ``get_visible_landmark_r_theta`` result per state, in input order.

        Raises:
            AssertionError: if the head was not created
                (``config["do_object_detection"]`` was false). The check is an
                explicit raise so it still fires under ``python -O``.
        """
        detector = self.object_detection_module
        if detector is None:
            raise AssertionError("Object detection module not created. Check config.")

        # NOTE(review): an earlier comment described each item as a set of
        # landmark indices, but the variable name used a dict — confirm the
        # return type of get_visible_landmark_r_theta.
        landmarks_visible = []
        for aos in agent_observed_state_list:
            x_pos, z_pos, y_angle = aos.get_position_orientation()
            landmarks_visible.append(detector.get_visible_landmark_r_theta(
                x_pos, z_pos, y_angle, aos.get_landmark_pos_dict()))

        # Original author's note: shape is BATCH_SIZE x num objects x 2.
        landmark_log_prob, distance_log_prob, theta_log_prob = detector(batch_image_feature)

        return landmark_log_prob, distance_log_prob, theta_log_prob, landmarks_visible

    def get_language_prediction_probs(self, batch_input):
        """Return the symbolic-language-prediction head's output for *batch_input*.

        Raises:
            AssertionError: if the head was not created
                (``config["do_symbolic_language_prediction"]`` was false). The
                check is an explicit raise so it still fires under ``python -O``.
        """
        if self.symbolic_language_prediction_module is None:
            raise AssertionError("Language prediction module not created. Check config.")
        return self.symbolic_language_prediction_module(batch_input)

    def init_weights(self):

        self.text_module.init_weights()
        self.image_recurrence_module.init_weights()
        self.image_module.init_weights()

    def share_memory(self):
        """Call ``share_memory()`` on every sub-module that exists.

        Optional auxiliary heads that are ``None`` are skipped.
        """
        modules = (
            self.image_module,
            self.image_recurrence_module,
            self.text_module,
            self.action_module,
            self.final_module,
            self.action_prediction_module,
            self.temporal_autoencoder_module,
            self.object_detection_module,
            self.symbolic_language_prediction_module,
            self.goal_prediction_module,
        )
        for module in modules:
            if module is not None:
                module.share_memory()

    def get_state_dict(self):
        """Gather every sub-module's ``state_dict()`` into one nested dict.

        The five core modules are stored under their attribute names. Auxiliary
        heads use short keys (``ap_module``, ``tae_module``, ``od_module``,
        ``sym_lang_module``, ``goal_pred_module``) and are omitted when ``None``.
        """
        core_entries = (
            ("image_module", "image_module"),
            ("image_recurrence_module", "image_recurrence_module"),
            ("text_module", "text_module"),
            ("action_module", "action_module"),
            ("final_module", "final_module"),
        )
        auxiliary_entries = (
            ("ap_module", "action_prediction_module"),
            ("tae_module", "temporal_autoencoder_module"),
            ("od_module", "object_detection_module"),
            ("sym_lang_module", "symbolic_language_prediction_module"),
            ("goal_pred_module", "goal_prediction_module"),
        )
        collected = {}
        for key, attr_name in core_entries:
            collected[key] = getattr(self, attr_name).state_dict()
        for key, attr_name in auxiliary_entries:
            auxiliary = getattr(self, attr_name)
            if auxiliary is not None:
                collected[key] = auxiliary.state_dict()
        return collected

    def load_from_state_dict(self, nested_state_dict):
        """Restore sub-modules from a dict shaped like ``get_state_dict()``'s output.

        Core modules are always loaded. Auxiliary heads are loaded only when they
        exist (not ``None``), in which case their key must be present; otherwise
        ``KeyError`` is raised.
        """
        core_entries = (
            ("image_module", "image_module"),
            ("image_recurrence_module", "image_recurrence_module"),
            ("text_module", "text_module"),
            ("action_module", "action_module"),
            ("final_module", "final_module"),
        )
        auxiliary_entries = (
            ("action_prediction_module", "ap_module"),
            ("temporal_autoencoder_module", "tae_module"),
            ("object_detection_module", "od_module"),
            ("symbolic_language_prediction_module", "sym_lang_module"),
            ("goal_prediction_module", "goal_pred_module"),
        )
        for attr_name, key in core_entries:
            getattr(self, attr_name).load_state_dict(nested_state_dict[key])
        for attr_name, key in auxiliary_entries:
            auxiliary = getattr(self, attr_name)
            if auxiliary is None:
                continue
            auxiliary.load_state_dict(nested_state_dict[key])

    def load_resnet_model(self, load_dir):
        """Load ``image_module`` weights from ``<load_dir>/image_module_state.bin``.

        Without CUDA, an identity ``map_location`` keeps tensors on the storage
        they were deserialised into.
        """
        if torch.cuda.is_available():
            load_options = {}
        else:
            load_options = {"map_location": lambda storage, location: storage}
        checkpoint_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(torch.load(checkpoint_path, **load_options))

    def fix_resnet(self):
        """Delegate to ``image_module.fix_resnet()``; its return value is discarded."""
        image_module = self.image_module
        image_module.fix_resnet()

    def load_lstm_model(self, load_dir):
        """Load ``text_module`` weights from ``<load_dir>/text_module_state.bin``.

        Without CUDA, an identity ``map_location`` keeps tensors on the storage
        they were deserialised into.
        """
        load_options = ({} if torch.cuda.is_available()
                        else {"map_location": lambda storage, location: storage})
        text_state_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(torch.load(text_state_path, **load_options))

    def load_saved_model(self, load_dir):
        """Restore all sub-modules from the per-module files that ``save_model`` writes.

        The five core files must exist. An auxiliary file is read only when the
        matching auxiliary head exists (is not ``None``). Without CUDA, an
        identity ``map_location`` keeps tensors on the storage they were
        deserialised into.
        """
        if torch.cuda.is_available():
            read_checkpoint = torch.load
        else:
            def read_checkpoint(path):
                return torch.load(path, map_location=lambda storage, location: storage)

        core_files = (
            ("image_module", "image_module_state.bin"),
            ("image_recurrence_module", "image_recurrence_module_state.bin"),
            ("text_module", "text_module_state.bin"),
            ("action_module", "action_module_state.bin"),
            ("final_module", "final_module_state.bin"),
        )
        auxiliary_files = (
            ("action_prediction_module", "auxiliary_action_prediction.bin"),
            ("temporal_autoencoder_module", "auxiliary_temporal_autoencoder.bin"),
            ("object_detection_module", "auxiliary_object_detection.bin"),
            ("symbolic_language_prediction_module", "auxiliary_symbolic_language_prediction.bin"),
            ("goal_prediction_module", "auxiliary_goal_prediction.bin"),
        )
        for attr_name, file_name in core_files:
            checkpoint_path = os.path.join(load_dir, file_name)
            getattr(self, attr_name).load_state_dict(read_checkpoint(checkpoint_path))
        for attr_name, file_name in auxiliary_files:
            auxiliary = getattr(self, attr_name)
            if auxiliary is None:
                continue
            checkpoint_path = os.path.join(load_dir, file_name)
            auxiliary.load_state_dict(read_checkpoint(checkpoint_path))

    def save_model(self, save_dir):
        """Write each sub-module's ``state_dict()`` to its own file under ``save_dir``.

        ``save_dir`` is created (including parents) if it does not exist. The
        five core modules are always saved. Each auxiliary head is saved only
        when it exists (is not ``None``).
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        core_files = (
            ("image_module", "image_module_state.bin"),
            ("image_recurrence_module", "image_recurrence_module_state.bin"),
            ("text_module", "text_module_state.bin"),
            ("action_module", "action_module_state.bin"),
            ("final_module", "final_module_state.bin"),
        )
        auxiliary_files = (
            ("action_prediction_module", "auxiliary_action_prediction.bin"),
            ("temporal_autoencoder_module", "auxiliary_temporal_autoencoder.bin"),
            ("object_detection_module", "auxiliary_object_detection.bin"),
            ("symbolic_language_prediction_module", "auxiliary_symbolic_language_prediction.bin"),
            ("goal_prediction_module", "auxiliary_goal_prediction.bin"),
        )
        for attr_name, file_name in core_files:
            target_path = os.path.join(save_dir, file_name)
            torch.save(getattr(self, attr_name).state_dict(), target_path)
        for attr_name, file_name in auxiliary_files:
            auxiliary = getattr(self, attr_name)
            if auxiliary is None:
                continue
            target_path = os.path.join(save_dir, file_name)
            torch.save(auxiliary.state_dict(), target_path)

    def get_parameters(self):
        """Return one flat list of parameters from all sub-modules.

        Core modules come first (image, image recurrence, text, action, final),
        followed by each auxiliary head that exists, in a fixed order.
        """
        core_names = ("image_module", "image_recurrence_module", "text_module",
                      "action_module", "final_module")
        auxiliary_names = ("action_prediction_module", "temporal_autoencoder_module",
                           "object_detection_module", "symbolic_language_prediction_module",
                           "goal_prediction_module")
        collected = []
        for name in core_names:
            collected.extend(getattr(self, name).parameters())
        for name in auxiliary_names:
            auxiliary = getattr(self, name)
            if auxiliary is not None:
                collected.extend(auxiliary.parameters())
        return collected

    def get_named_parameters(self):
        """Return one flat list of ``(name, parameter)`` pairs from all sub-modules.

        The ordering matches ``get_parameters``: core modules first, then each
        auxiliary head that exists.
        """
        core_names = ("image_module", "image_recurrence_module", "text_module",
                      "action_module", "final_module")
        auxiliary_names = ("action_prediction_module", "temporal_autoencoder_module",
                           "object_detection_module", "symbolic_language_prediction_module",
                           "goal_prediction_module")
        modules = [getattr(self, name) for name in core_names]
        modules += [aux for aux in (getattr(self, name) for name in auxiliary_names)
                    if aux is not None]
        return [pair for module in modules for pair in module.named_parameters()]
Exemplo n.º 5
0
class IncrementalModelEmnlp(AbstractIncrementalModel):
    """Incremental policy model that is queried one observation at a time.

    Three encoders feed ``IncrementalMultimodalEmnlp``: an image CNN over a
    channel-stacked window of the most recent images, an LSTM text encoder,
    and a previous-action embedding. ``get_probs`` takes a ``model_state`` and
    returns the updated one, so callers thread it through successive steps.
    """

    # Number of most recent images that are stacked along the channel axis, and
    # the channels each image contributes. The CNN input width and the reshape
    # in get_probs are both derived from these two values, so they stay in sync.
    NUM_HISTORY_IMAGES = 5
    CHANNELS_PER_IMAGE = 3

    def __init__(self, config, constants):
        AbstractIncrementalModel.__init__(self, config, constants)
        # Index used to encode "no previous action" (== config["num_actions"]).
        self.none_action = config["num_actions"]

        self.config = config
        self.constants = constants

        # CNN over images - using what is essentially SimpleImage currently
        self.image_module = ImageCnnEmnlp(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=self.CHANNELS_PER_IMAGE * self.NUM_HISTORY_IMAGES,
            image_height=config["image_height"],
            image_width=config["image_width"])

        # LSTM to embed text
        self.text_module = TextSimpleModule(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=constants["lstm_emb_dim"],
            vocab_size=config["vocab_size"])

        # Action module to embed previous action+block
        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])

        # Put it all together
        self.final_module = IncrementalMultimodalEmnlp(
            image_module=self.image_module,
            text_module=self.text_module,
            action_module=self.action_module,
            input_embedding_size=(constants["lstm_emb_dim"]
                                  + constants["image_emb_dim"]
                                  + constants["action_emb_dim"]),
            output_hidden_size=config["h1_hidden_dim"],
            blocks_hidden_size=config["blocks_hidden_dim"],
            directions_hidden_size=config["action_hidden_dim"],
            max_episode_length=(constants["horizon"] + 5))

        if torch.cuda.is_available():
            self.image_module.cuda()
            self.text_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Not supported: this model is only queried one state at a time via ``get_probs``."""
        raise NotImplementedError()

    def get_probs(self,
                  agent_observed_state,
                  model_state,
                  mode=None,
                  volatile=False):
        """Compute action probabilities for a single observed state.

        Args:
            agent_observed_state: the current ``AgentObservedState``.
            model_state: recurrent state returned by the previous call. It is
                passed straight to ``final_module``.
            mode: not used by this method; kept for interface compatibility.
            volatile: forwarded to ``cuda_var`` for every input built here.

        Returns:
            ``(probs_batch, new_model_state, None, None)``. The trailing
            ``None`` values keep the 4-tuple shape returned by other
            incremental models.
        """
        assert isinstance(agent_observed_state, AgentObservedState)

        num_images = self.NUM_HISTORY_IMAGES
        # The image list is expected to already be zero-padded when fewer than
        # NUM_HISTORY_IMAGES images are available; keep only the most recent ones.
        images = agent_observed_state.get_image()[-num_images:]
        image_batch = cuda_var(
            torch.from_numpy(np.array(images)).float(), volatile)

        # Stack the images along the channel axis. The batch size is always 1 here.
        image_batch = image_batch.view(1, self.CHANNELS_PER_IMAGE * num_images,
                                       self.config["image_height"],
                                       self.config["image_width"])

        # final_module expects an (instructions, <second element>) pair; False
        # is passed as the second element. TODO: confirm what it is used for.
        instructions_batch = ([agent_observed_state.get_instruction()], False)

        # A missing previous action is encoded with the dedicated index self.none_action.
        prev_action = agent_observed_state.get_previous_action()
        prev_actions = [self.none_action if prev_action is None else prev_action]
        # volatile is forwarded here too, consistent with image_batch above.
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)),
                                      volatile)

        probs_batch, new_model_state = self.final_module(
            image_batch, instructions_batch, prev_actions_batch, model_state)

        # The last two slots are unused by this model.
        return probs_batch, new_model_state, None, None

    def init_weights(self):
        """Initialise the text, image, action and final modules, in that order."""
        self.text_module.init_weights()
        self.image_module.init_weights()
        self.action_module.init_weights()
        self.final_module.init_weights()

    def share_memory(self):
        """Call ``share_memory()`` on the image, text, action and final modules."""
        self.image_module.share_memory()
        self.text_module.share_memory()
        self.action_module.share_memory()
        self.final_module.share_memory()

    def get_state_dict(self):
        """Return a dict mapping each sub-module's attribute name to its ``state_dict()``."""
        nested_state_dict = dict()
        nested_state_dict["image_module"] = self.image_module.state_dict()
        nested_state_dict["text_module"] = self.text_module.state_dict()
        nested_state_dict["action_module"] = self.action_module.state_dict()
        nested_state_dict["final_module"] = self.final_module.state_dict()

        return nested_state_dict

    def load_from_state_dict(self, nested_state_dict):
        """Restore sub-modules from a dict shaped like ``get_state_dict()``'s output."""
        self.image_module.load_state_dict(nested_state_dict["image_module"])
        self.text_module.load_state_dict(nested_state_dict["text_module"])
        self.action_module.load_state_dict(nested_state_dict["action_module"])
        self.final_module.load_state_dict(nested_state_dict["final_module"])

    def load_resnet_model(self, load_dir):
        """Load ``image_module`` weights from ``<load_dir>/image_module_state.bin``."""
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            # Identity map_location: keep tensors on the storage they were
            # deserialised into (no CUDA device available).
            torch_load = lambda f_: torch.load(f_,
                                               map_location=lambda s_, l_: s_)
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(torch_load(image_module_path))

    def load_lstm_model(self, load_dir):
        """Load ``text_module`` weights from ``<load_dir>/text_module_state.bin``."""
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            torch_load = lambda f_: torch.load(f_,
                                               map_location=lambda s_, l_: s_)
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(torch_load(text_module_path))

    def load_saved_model(self, load_dir):
        """Restore all four sub-modules from the files written by ``save_model``."""
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            torch_load = lambda f_: torch.load(f_,
                                               map_location=lambda s_, l_: s_)
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(torch_load(image_module_path))
        action_module_path = os.path.join(load_dir, "action_module_state.bin")
        self.action_module.load_state_dict(torch_load(action_module_path))
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(torch_load(text_module_path))
        final_module_path = os.path.join(load_dir, "final_module_state.bin")
        self.final_module.load_state_dict(torch_load(final_module_path))

    def save_model(self, save_dir):
        """Write each sub-module's ``state_dict()`` to its own file in ``save_dir``.

        ``save_dir`` is created (including parents) if it does not exist.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # save state file for image nn
        image_module_path = os.path.join(save_dir, "image_module_state.bin")
        torch.save(self.image_module.state_dict(), image_module_path)
        # save state file for action emb
        action_module_path = os.path.join(save_dir, "action_module_state.bin")
        torch.save(self.action_module.state_dict(), action_module_path)
        # save state file for text nn
        text_module_path = os.path.join(save_dir, "text_module_state.bin")
        torch.save(self.text_module.state_dict(), text_module_path)
        # save state file for final nn
        final_module_path = os.path.join(save_dir, "final_module_state.bin")
        torch.save(self.final_module.state_dict(), final_module_path)

    def get_parameters(self):
        """Return the trainable parameters as a list.

        Only ``final_module.parameters()`` is collected. ``final_module`` is
        constructed with the image, text and action modules; presumably it
        registers them as sub-modules, so their parameters are included — confirm.
        """
        parameters = list(self.final_module.parameters())

        return parameters

    def get_named_parameters(self):
        """Return ``final_module.named_parameters()`` as a list (see ``get_parameters``)."""
        named_parameters = list(self.final_module.named_parameters())
        return named_parameters
class ModelPolicyNetworkSymbolicText(AbstractModel):
    """Policy network that reads a symbolic instruction instead of raw text.

    The instruction is embedded by ``SymbolicInstructionModule``, which is built
    from radius, angle and landmark embedding modules. Images go through a
    ResNet image module and a recurrence module.
    ``MultimodalRecurrentSimpleModule`` combines these with a previous-action
    embedding.
    """

    def __init__(self, config, constants):
        AbstractModel.__init__(self, config, constants)
        # Index used to encode "no previous action" (== config["num_actions"]).
        self.none_action = config["num_actions"]
        # NOTE(review): constructor arguments 15 / 48 / 63 are hard-coded —
        # presumably vocabulary sizes of the symbolic instruction; confirm.
        self.radius_module = RadiusModule(15)
        self.angle_module = AngleModule(48)
        self.landmark_module = LandmarkModule(63)

        self.image_module = ImageResnetModule(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        self.image_recurrence_module = RecurrenceSimpleModule(
            input_emb_dim=constants["image_emb_dim"],
            output_emb_dim=constants["image_emb_dim"])

        self.text_module = SymbolicInstructionModule(
            radius_embedding=self.radius_module,
            theta_embedding=self.angle_module,
            landmark_embedding=self.landmark_module)
        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])
        # 32 * 4: presumably the output width of SymbolicInstructionModule — confirm.
        total_emb_size = (constants["image_emb_dim"]
                          + 32 * 4
                          + constants["action_emb_dim"])
        final_module = MultimodalRecurrentSimpleModule(
            image_module=self.image_module,
            image_recurrence_module=self.image_recurrence_module,
            text_module=self.text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])
        self.final_module = final_module
        if torch.cuda.is_available():
            self.image_module.cuda()
            # Moved explicitly, as the other recurrent models do. This is
            # harmless if final_module already owns it as a sub-module.
            self.image_recurrence_module.cuda()
            self.text_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()
            self.radius_module.cuda()
            self.angle_module.cuda()
            self.landmark_module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Compute action probabilities for a batch of observed states.

        The batch is first sorted by raw instruction length (longest first).
        The returned probabilities follow that sorted order, not the caller's order.
        """
        for aos in agent_observed_state_list:
            assert isinstance(aos, AgentObservedState)

        # sort list by instruction length
        agent_observed_state_list = sorted(
            agent_observed_state_list,
            key=lambda aos_: len(aos_.get_instruction()),
            reverse=True
        )

        image_seq_lens = [aos.get_num_images()
                          for aos in agent_observed_state_list]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        max_len = max(image_seq_lens)
        image_seqs = [aos.get_image()[:max_len]
                      for aos in agent_observed_state_list]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float())

        instructions_batch = [aos.get_symbolic_instruction()
                              for aos in agent_observed_state_list]

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        probs_batch = self.final_module(image_batch, image_seq_lens_batch,
                                        instructions_batch, prev_actions_batch,
                                        mode)
        return probs_batch

    def load_saved_model(self, load_dir):
        """Restore sub-modules from the per-module files written by ``save_model``."""
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            # Identity map_location: keep tensors on the storage they were deserialised into.
            torch_load = lambda f_: torch.load(f_, map_location=lambda s_, l_: s_)
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(torch_load(image_module_path))
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(torch_load(text_module_path))
        action_module_path = os.path.join(load_dir, "action_module_state.bin")
        self.action_module.load_state_dict(torch_load(action_module_path))
        final_module_path = os.path.join(load_dir, "final_module_state.bin")
        self.final_module.load_state_dict(torch_load(final_module_path))
        radius_module_path = os.path.join(load_dir, "radius_module_state.bin")
        self.radius_module.load_state_dict(torch_load(radius_module_path))
        angle_module_path = os.path.join(load_dir, "angle_module_state.bin")
        self.angle_module.load_state_dict(torch_load(angle_module_path))
        landmark_module_path = os.path.join(load_dir, "landmark_module_state.bin")
        self.landmark_module.load_state_dict(torch_load(landmark_module_path))

    def save_model(self, save_dir):
        """Write each sub-module's ``state_dict()`` to its own file in ``save_dir``.

        ``save_dir`` is created (including parents) if it does not exist.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # save state file for image nn
        image_module_path = os.path.join(save_dir, "image_module_state.bin")
        torch.save(self.image_module.state_dict(), image_module_path)
        # save state file for text nn
        text_module_path = os.path.join(save_dir, "text_module_state.bin")
        torch.save(self.text_module.state_dict(), text_module_path)
        # save state file for action emb
        action_module_path = os.path.join(save_dir, "action_module_state.bin")
        torch.save(self.action_module.state_dict(), action_module_path)
        # save state file for final nn
        final_module_path = os.path.join(save_dir, "final_module_state.bin")
        torch.save(self.final_module.state_dict(), final_module_path)
        # save state file for radius nn
        radius_module_path = os.path.join(save_dir, "radius_module_state.bin")
        torch.save(self.radius_module.state_dict(), radius_module_path)
        # save state file for angle nn
        angle_module_path = os.path.join(save_dir, "angle_module_state.bin")
        torch.save(self.angle_module.state_dict(), angle_module_path)
        # save state file for landmark nn
        landmark_module_path = os.path.join(save_dir, "landmark_module_state.bin")
        torch.save(self.landmark_module.state_dict(), landmark_module_path)

    def get_parameters(self):
        """Return a flat list of parameters from the image, text, action, final and embedding modules."""
        parameters = list(self.image_module.parameters())
        parameters += list(self.text_module.parameters())
        parameters += list(self.action_module.parameters())
        parameters += list(self.final_module.parameters())
        parameters += list(self.radius_module.parameters())
        parameters += list(self.angle_module.parameters())
        parameters += list(self.landmark_module.parameters())
        return parameters
class IncrementalModelRecurrentPolicyNetworkSymbolicTextWithLSTMResnet(
        AbstractIncrementalModel):
    def __init__(self, config, constants):
        """Build the symbolic-text + LSTM-text incremental policy and its optional auxiliary heads.

        Auxiliary heads (action prediction, temporal autoencoding, object
        detection) are created only when the matching ``config`` flag is set.
        Otherwise the attribute is ``None``.
        """
        AbstractIncrementalModel.__init__(self, config, constants)
        # Index used to encode "no previous action" (== config["num_actions"]).
        self.none_action = config["num_actions"]
        # NOTE(review): embedding constructor arguments are hard-coded
        # (AngleModule previously used 48); 67 matches num_objects below.
        self.radius_module = RadiusModule(15)
        self.angle_module = AngleModule(12)
        self.landmark_module = LandmarkModule(67)
        self.num_cameras = 1
        self.image_module = ImageRyanResnetModule(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        # Input is the concatenated per-camera image embeddings (the previous
        # action embedding is not included).
        self.image_recurrence_module = IncrementalRecurrenceSimpleModule(
            input_emb_dim=constants["image_emb_dim"] * self.num_cameras,
            output_emb_dim=constants["image_emb_dim"])
        self.symbolic_text_module = SymbolicInstructionModule(
            radius_embedding=self.radius_module,
            theta_embedding=self.angle_module,
            landmark_embedding=self.landmark_module)
        self.lstm_text_module = TextSimpleModule(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=constants["lstm_emb_dim"],
            vocab_size=config["vocab_size"])
        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])
        # Image embeddings + symbolic-instruction embedding (32 * 2, presumably
        # SymbolicInstructionModule's output width — confirm) + LSTM text
        # embedding + previous-action embedding.
        total_emb_size = (self.num_cameras * constants["image_emb_dim"]
                          + 32 * 2
                          + constants["lstm_emb_dim"]
                          + constants["action_emb_dim"])

        if config["do_action_prediction"]:
            self.action_prediction_module = ActionPredictionModule(
                2 * self.num_cameras * constants["image_emb_dim"],
                constants["image_emb_dim"], config["num_actions"])
        else:
            self.action_prediction_module = None

        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_module = TemporalAutoencoderModule(
                self.action_module,
                self.num_cameras * constants["image_emb_dim"],
                constants["action_emb_dim"], constants["image_emb_dim"])
        else:
            self.temporal_autoencoder_module = None

        if config["do_object_detection"]:
            self.landmark_names = get_all_landmark_names()
            self.object_detection_module = ObjectDetectionModule(
                image_module=self.image_module,
                image_emb_size=self.num_cameras * constants["image_emb_dim"],
                num_objects=67)
        else:
            self.object_detection_module = None

        final_module = IncrementalMultimodalMixedTextRecurrentSimpleModule(
            image_module=self.image_module,
            image_recurrence_module=self.image_recurrence_module,
            symbolic_text_module=self.symbolic_text_module,
            lstm_text_module=self.lstm_text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])
        self.final_module = final_module
        if torch.cuda.is_available():
            self.image_module.cuda()
            self.image_recurrence_module.cuda()
            self.symbolic_text_module.cuda()
            self.lstm_text_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()
            if self.action_prediction_module is not None:
                self.action_prediction_module.cuda()
            if self.temporal_autoencoder_module is not None:
                self.temporal_autoencoder_module.cuda()
            if self.object_detection_module is not None:
                self.object_detection_module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Batched inference is not supported by this incremental model.

        Use ``get_probs`` one observed state at a time, threading ``model_state``
        between calls.

        Raises:
            NotImplementedError: always.
        """
        # The batched implementation that used to follow this raise was
        # unreachable, and it called final_module with an outdated argument list.
        raise NotImplementedError(
            "Batched inference is not supported; call get_probs per state.")

    def get_probs(self,
                  agent_observed_state,
                  model_state,
                  mode=None,
                  volatile=False):
        """Run one incremental step of the policy for a single observed state.

        A batch of size one is built from four parts of the state: its last
        image, its symbolic instruction, its (instruction, read pointers) pair
        for the LSTM, and its previous action. These are fed to
        ``final_module`` together with ``mode`` and ``model_state``.
        ``volatile`` is forwarded to ``cuda_var`` for the image and action
        inputs.

        Returns:
            ``(probs_batch, new_model_state, image_emb_seq, state_feature)`` as
            produced by ``final_module``.
        """
        assert isinstance(agent_observed_state, AgentObservedState)
        state = agent_observed_state

        # Exactly one image step per call.
        seq_lens_batch = cuda_tensor(torch.from_numpy(np.array([1])))
        image_batch = cuda_var(
            torch.from_numpy(np.array([[state.get_last_image()]])).float(),
            volatile)

        symbolic_batch = [state.get_symbolic_instruction()]
        lstm_batch = ([state.get_instruction()], [state.get_read_pointers()])

        previous_action = state.get_previous_action()
        if previous_action is None:
            previous_action = self.none_action
        action_batch = cuda_var(torch.from_numpy(np.array([previous_action])),
                                volatile)

        outputs = self.final_module(image_batch, seq_lens_batch, symbolic_batch,
                                    lstm_batch, action_batch, mode, model_state)
        probs_batch, new_model_state, image_emb_seq, state_feature = outputs
        return probs_batch, new_model_state, image_emb_seq, state_feature

    def get_probs_symbolic_text(self,
                                agent_observed_state,
                                symbolic_text,
                                model_state,
                                mode=None,
                                volatile=False):
        """Like ``get_probs``, but force the symbolic input to ``symbolic_text``.

        The natural-language instruction and read pointers still come from
        ``agent_observed_state``. Only the symbolic instruction, which
        ``get_probs`` reads with ``get_symbolic_instruction()``, is replaced
        by the caller-supplied ``symbolic_text``.

        :param agent_observed_state: the AgentObservedState to score.
        :param symbolic_text: symbolic instruction to use instead of the
            state's own.
        :param model_state: recurrent state, passed straight to
            ``self.final_module``.
        :param mode: forwarded unchanged to ``self.final_module``.
        :param volatile: forwarded to ``cuda_var``.
        :return: ``(probs_batch, new_model_state, image_emb_seq,
            state_feature)`` as produced by ``self.final_module``.
        """
        assert isinstance(agent_observed_state, AgentObservedState)
        state = agent_observed_state

        seq_lens = cuda_tensor(torch.from_numpy(np.array([1])))
        last_image = state.get_last_image()
        images = cuda_var(
            torch.from_numpy(np.array([[last_image]])).float(), volatile)

        # Batch of one: the forced symbolic text replaces the state's own.
        symbolic_batch = [symbolic_text]
        real_instruction_batch = ([state.get_instruction()],
                                  [state.get_read_pointers()])

        previous = state.get_previous_action()
        action_id = self.none_action if previous is None else previous
        prev_actions = cuda_var(torch.from_numpy(np.array([action_id])),
                                volatile)

        result = self.final_module(images, seq_lens, symbolic_batch,
                                   real_instruction_batch, prev_actions,
                                   mode, model_state)
        probs_batch, new_model_state, image_emb_seq, state_feature = result
        return probs_batch, new_model_state, image_emb_seq, state_feature

    def action_prediction_log_prob(self, batch_input):
        """Score ``batch_input`` with the auxiliary action-prediction head.

        :raises AssertionError: if ``self.action_prediction_module`` is None.
        """
        predictor = self.action_prediction_module
        assert predictor is not None, "Action prediction module not created. Check config."
        return predictor(batch_input)

    def predict_action_result(self, batch_image_feature, action_batch):
        """Apply the temporal-autoencoder head to image features and actions.

        :raises AssertionError: if ``self.temporal_autoencoder_module`` is None.
        """
        autoencoder = self.temporal_autoencoder_module
        assert autoencoder is not None, "Temporal action module not created. Check config."
        return autoencoder(batch_image_feature, action_batch)

    def get_probs_and_visible_objects(self, agent_observed_state_list,
                                      batch_image_feature):
        """Run the object-detection head and collect ground-truth visibility.

        For every observed state, the agent pose and the landmark positions are
        passed to ``get_visible_landmark_r_theta``. The detector itself is then
        applied to ``batch_image_feature``.

        :return: ``(landmark_log_prob, distance_log_prob, theta_log_prob,
            landmarks_visible)``. The first three come from the detector; their
            shapes are not visible here (TODO confirm). ``landmarks_visible``
            is a list, one entry per state, holding whatever
            ``get_visible_landmark_r_theta`` returned.
        :raises AssertionError: if ``self.object_detection_module`` is None.
        """
        detector = self.object_detection_module
        assert detector is not None, "Object detection module not created. Check config."

        def _visible_for(aos):
            x_pos, z_pos, y_angle = aos.get_position_orientation()
            return detector.get_visible_landmark_r_theta(
                x_pos, z_pos, y_angle, aos.get_landmark_pos_dict())

        landmarks_visible = [_visible_for(aos)
                             for aos in agent_observed_state_list]

        landmark_log_prob, distance_log_prob, theta_log_prob = detector(
            batch_image_feature)
        return landmark_log_prob, distance_log_prob, theta_log_prob, landmarks_visible

    def init_weights(self):
        """Re-initialise the image encoder, then the image recurrence module."""
        for module_name in ("image_module", "image_recurrence_module"):
            getattr(self, module_name).init_weights()

    def load_resnet_model(self, load_dir):
        """Load only the image encoder's weights from ``load_dir``.

        Reads ``image_module_state.bin``. Without CUDA, every storage is
        remapped onto the CPU, presumably so checkpoints saved on a GPU can
        still be loaded.
        """
        # With CUDA, keep torch.load's default placement (map_location=None).
        if torch.cuda.is_available():
            map_location = None
        else:
            map_location = lambda storage, location: storage
        state_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(
            torch.load(state_path, map_location=map_location))

    def load_saved_model(self, load_dir):
        """Restore every sub-module's weights from the files in ``load_dir``.

        The six core modules are always loaded. Each auxiliary head is loaded
        only when it is not None. Without CUDA, storages are remapped onto the
        CPU.
        """
        if torch.cuda.is_available():
            map_location = None
        else:
            map_location = lambda storage, location: storage

        core_files = (
            ("image_module", "image_module_state.bin"),
            ("image_recurrence_module", "image_recurrence_module_state.bin"),
            ("symbolic_text_module", "symbolic_text_module_state.bin"),
            ("lstm_text_module", "lstm_text_module_state.bin"),
            ("action_module", "action_module_state.bin"),
            ("final_module", "final_module_state.bin"),
        )
        auxiliary_files = (
            ("action_prediction_module", "auxiliary_action_prediction.bin"),
            ("temporal_autoencoder_module",
             "auxiliary_temporal_autoencoder.bin"),
            ("object_detection_module", "auxiliary_object_detection.bin"),
        )

        def _restore(module, file_name):
            state = torch.load(os.path.join(load_dir, file_name),
                               map_location=map_location)
            module.load_state_dict(state)

        for attr, file_name in core_files:
            _restore(getattr(self, attr), file_name)
        for attr, file_name in auxiliary_files:
            module = getattr(self, attr)
            if module is not None:
                _restore(module, file_name)

    def save_model(self, save_dir):
        """Write every sub-module's ``state_dict`` into ``save_dir``.

        Creates ``save_dir`` if it does not exist. The six core modules are
        always saved. Each auxiliary head is saved only when it is not None.
        The file names are the ones ``load_saved_model`` reads back.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        core_files = (
            ("image_module", "image_module_state.bin"),
            ("image_recurrence_module", "image_recurrence_module_state.bin"),
            ("symbolic_text_module", "symbolic_text_module_state.bin"),
            ("lstm_text_module", "lstm_text_module_state.bin"),
            ("action_module", "action_module_state.bin"),
            ("final_module", "final_module_state.bin"),
        )
        auxiliary_files = (
            ("action_prediction_module", "auxiliary_action_prediction.bin"),
            ("temporal_autoencoder_module",
             "auxiliary_temporal_autoencoder.bin"),
            ("object_detection_module", "auxiliary_object_detection.bin"),
        )

        for attr, file_name in core_files:
            torch.save(getattr(self, attr).state_dict(),
                       os.path.join(save_dir, file_name))
        for attr, file_name in auxiliary_files:
            module = getattr(self, attr)
            if module is not None:
                torch.save(module.state_dict(),
                           os.path.join(save_dir, file_name))

    def get_parameters(self):
        """Return the trainable parameters as a flat list without duplicates.

        The image encoder (``self.image_module``) is not added directly; its
        line is commented out. Auxiliary heads are included only when they are
        not None.

        Sub-modules may share children. For example, a final module that wraps
        the other modules would report their parameters again (TODO confirm for
        this model). The same ``Parameter`` would then appear more than once.
        PyTorch optimizers warn about duplicate parameters in a group, and such
        a parameter may be stepped more than once. To avoid that, each
        parameter is kept only at its first occurrence, and order is otherwise
        preserved.
        """
        # parameters = list(self.image_module.parameters())
        modules = [self.image_recurrence_module,
                   self.symbolic_text_module,
                   self.lstm_text_module,
                   self.action_module,
                   self.final_module]
        for auxiliary in (self.action_prediction_module,
                          self.temporal_autoencoder_module,
                          self.object_detection_module):
            if auxiliary is not None:
                modules.append(auxiliary)

        parameters = []
        seen_ids = set()
        for module in modules:
            for param in module.parameters():
                if id(param) not in seen_ids:
                    seen_ids.add(id(param))
                    parameters.append(param)
        return parameters

    def get_named_parameters(self):
        """Return ``(name, parameter)`` pairs for the trainable sub-modules.

        The image encoder is left out, as its line is commented out, and the
        auxiliary heads are included only when they are not None. Names are
        relative to each sub-module, so they can repeat across modules.
        """
        # Left out: list(self.image_module.named_parameters())
        named_parameters = []
        for attr in ("image_recurrence_module", "symbolic_text_module",
                     "lstm_text_module", "action_module", "final_module"):
            named_parameters.extend(getattr(self, attr).named_parameters())
        for attr in ("action_prediction_module",
                     "temporal_autoencoder_module",
                     "object_detection_module"):
            module = getattr(self, attr)
            if module is not None:
                named_parameters.extend(module.named_parameters())
        # Disabled:
        # if self.symbolic_language_prediction_module is not None:
        #     named_parameters += list(
        #         self.symbolic_language_prediction_module.named_parameters())
        return named_parameters
class IncrementalModelRecurrentPolicyNetworkGoalImageResnet(AbstractIncrementalModel):
    """Incremental recurrent policy conditioned on a goal image.

    There is no text instruction. Each step feeds the most recent image, a
    goal image and the previous action to
    ``IncrementalMultimodalRecurrentSimpleGoalImageModule``, which returns the
    action probabilities. Auxiliary heads (action prediction, temporal
    autoencoding, object detection) are built only when the matching
    ``config`` flag is set; otherwise they are None.
    """

    def __init__(self, config, constants):
        AbstractIncrementalModel.__init__(self, config, constants)
        # Substituted for a missing previous action in get_probs.
        self.none_action = config["num_actions"]

        self.image_module = ImageResnetModule(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        self.image_recurrence_module = IncrementalRecurrenceSimpleModule(
            input_emb_dim=constants["image_emb_dim"],
            output_emb_dim=constants["image_emb_dim"])
        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])
        # Two image-sized embeddings (presumably current view and goal) plus
        # the previous-action embedding.
        total_emb_size = (2 * constants["image_emb_dim"]
                          + constants["action_emb_dim"])

        if config["do_action_prediction"]:
            self.action_prediction_module = ActionPredictionModule(
                2 * constants["image_emb_dim"], constants["image_emb_dim"], config["num_actions"])
        else:
            self.action_prediction_module = None

        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_module = TemporalAutoencoderModule(
                self.action_module, constants["image_emb_dim"], constants["action_emb_dim"], constants["image_emb_dim"])
        else:
            self.temporal_autoencoder_module = None

        if config["do_object_detection"]:
            self.landmark_names = get_all_landmark_names()
            # The detector is handed the same image encoder instance.
            self.object_detection_module = ObjectDetectionModule(
                image_module=self.image_module, image_emb_size=constants["image_emb_dim"], num_objects=63)
        else:
            self.object_detection_module = None

        final_module = IncrementalMultimodalRecurrentSimpleGoalImageModule(
            image_module=self.image_module,
            image_recurrence_module=self.image_recurrence_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])
        self.final_module = final_module
        if torch.cuda.is_available():
            self.image_module.cuda()
            self.image_recurrence_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()
            if self.action_prediction_module is not None:
                self.action_prediction_module.cuda()
            if self.temporal_autoencoder_module is not None:
                self.temporal_autoencoder_module.cuda()
            if self.object_detection_module is not None:
                self.object_detection_module.cuda()

    @staticmethod
    def _load_state_file(path):
        """``torch.load`` ``path``; without CUDA, remap every storage to CPU."""
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda s_, l_: s_)

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Batched scoring is not supported by this incremental model.

        :raises NotImplementedError: always. Call ``get_probs`` one state at a
            time instead.
        """
        raise NotImplementedError(
            "get_probs_batch is not supported by "
            "IncrementalModelRecurrentPolicyNetworkGoalImageResnet; "
            "call get_probs one state at a time.")

    def get_probs(self, agent_observed_state, model_state, mode=None):
        """Run one incremental step for a single state and its goal image.

        :param agent_observed_state: the AgentObservedState to score. Its last
            image and goal image are used at their stored resolution.
        :param model_state: recurrent state, passed straight to
            ``self.final_module``.
        :param mode: forwarded unchanged to ``self.final_module``.
        :return: ``(probs_batch, new_model_state, image_emb_seq)`` as produced
            by ``self.final_module``.
        """
        assert isinstance(agent_observed_state, AgentObservedState)
        agent_observed_state_list = [agent_observed_state]

        # One time step per call.
        image_seq_lens = [1]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        image_seqs = [[aos.get_last_image()]
                      for aos in agent_observed_state_list]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float())

        goal_image_seqs = [[aos.get_goal_image()] for aos in agent_observed_state_list]
        goal_image_batch = cuda_var(torch.from_numpy(np.array(goal_image_seqs)).float())

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        probs_batch, new_model_state, image_emb_seq = self.final_module(image_batch, image_seq_lens_batch,
                                                                        goal_image_batch, prev_actions_batch,
                                                                        mode, model_state)
        return probs_batch, new_model_state, image_emb_seq

    def action_prediction_log_prob(self, batch_input):
        """Score ``batch_input`` with the auxiliary action-prediction head."""
        assert self.action_prediction_module is not None, "Action prediction module not created. Check config."
        return self.action_prediction_module(batch_input)

    def predict_action_result(self, batch_image_feature, action_batch):
        """Apply the temporal-autoencoder head to image features and actions."""
        assert self.temporal_autoencoder_module is not None, "Temporal action module not created. Check config."
        return self.temporal_autoencoder_module(batch_image_feature, action_batch)

    def get_probs_and_visible_objects(self, agent_observed_state_list, batch_image_feature):
        """Run the object-detection head and collect ground-truth visibility.

        :return: ``(probs_batch, landmarks_visible)``. ``probs_batch`` is the
            detector output, presumably (batch, 63, 2) according to the
            original note; 63 matches ``num_objects`` in ``__init__``.
            ``landmarks_visible`` has one entry per state.
        """
        assert self.object_detection_module is not None, "Object detection module not created. Check config."
        landmarks_visible = []
        for aos in agent_observed_state_list:
            x_pos, z_pos, y_angle = aos.get_position_orientation()
            landmark_pos_dict = aos.get_landmark_pos_dict()
            visible_landmarks = self.object_detection_module.get_visible_landmark_r_theta(
                x_pos, z_pos, y_angle, landmark_pos_dict, self.landmark_names)
            landmarks_visible.append(visible_landmarks)

        probs_batch = self.object_detection_module(batch_image_feature)
        return probs_batch, landmarks_visible

    def load_resnet_model(self, load_dir):
        """Load only the image encoder from ``load_dir/image_module_state.bin``."""
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._load_state_file(image_module_path))

    def load_saved_model(self, load_dir):
        """Restore all sub-modules; auxiliary heads only when they exist."""
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._load_state_file(image_module_path))
        image_recurrence_module_path = os.path.join(
            load_dir, "image_recurrence_module_state.bin")
        self.image_recurrence_module.load_state_dict(
            self._load_state_file(image_recurrence_module_path))
        action_module_path = os.path.join(load_dir, "action_module_state.bin")
        self.action_module.load_state_dict(self._load_state_file(action_module_path))
        final_module_path = os.path.join(load_dir, "final_module_state.bin")
        self.final_module.load_state_dict(self._load_state_file(final_module_path))
        if self.action_prediction_module is not None:
            auxiliary_action_prediction_path = os.path.join(load_dir, "auxiliary_action_prediction.bin")
            self.action_prediction_module.load_state_dict(
                self._load_state_file(auxiliary_action_prediction_path))
        if self.temporal_autoencoder_module is not None:
            auxiliary_temporal_autoencoder_path = os.path.join(load_dir, "auxiliary_temporal_autoencoder.bin")
            self.temporal_autoencoder_module.load_state_dict(
                self._load_state_file(auxiliary_temporal_autoencoder_path))
        if self.object_detection_module is not None:
            auxiliary_object_detection_path = os.path.join(load_dir, "auxiliary_object_detection.bin")
            self.object_detection_module.load_state_dict(
                self._load_state_file(auxiliary_object_detection_path))

    def save_model(self, save_dir):
        """Write each sub-module's ``state_dict`` under ``save_dir``.

        The file names are the ones ``load_saved_model`` reads back.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        image_module_path = os.path.join(save_dir, "image_module_state.bin")
        torch.save(self.image_module.state_dict(), image_module_path)
        image_recurrence_module_path = os.path.join(
            save_dir, "image_recurrence_module_state.bin")
        torch.save(self.image_recurrence_module.state_dict(),
                   image_recurrence_module_path)
        action_module_path = os.path.join(save_dir, "action_module_state.bin")
        torch.save(self.action_module.state_dict(), action_module_path)
        final_module_path = os.path.join(save_dir, "final_module_state.bin")
        torch.save(self.final_module.state_dict(), final_module_path)
        # Auxiliary heads are saved only when they were created.
        if self.action_prediction_module is not None:
            auxiliary_action_prediction_path = os.path.join(save_dir, "auxiliary_action_prediction.bin")
            torch.save(self.action_prediction_module.state_dict(), auxiliary_action_prediction_path)
        if self.temporal_autoencoder_module is not None:
            auxiliary_temporal_autoencoder_path = os.path.join(save_dir, "auxiliary_temporal_autoencoder.bin")
            torch.save(self.temporal_autoencoder_module.state_dict(), auxiliary_temporal_autoencoder_path)
        if self.object_detection_module is not None:
            auxiliary_object_detection_path = os.path.join(save_dir, "auxiliary_object_detection.bin")
            torch.save(self.object_detection_module.state_dict(), auxiliary_object_detection_path)

    def get_parameters(self):
        """Return all trainable parameters, each listed once.

        ``self.image_module`` is handed to both the final module and the
        object detector, so the same ``Parameter`` can be reported by several
        sub-modules. Duplicates are dropped (first occurrence wins), because
        PyTorch optimizers warn about duplicate parameters in a group.
        """
        modules = [self.image_module,
                   self.image_recurrence_module,
                   self.action_module,
                   self.final_module]
        for auxiliary in (self.action_prediction_module,
                          self.temporal_autoencoder_module,
                          self.object_detection_module):
            if auxiliary is not None:
                modules.append(auxiliary)

        parameters = []
        seen_ids = set()
        for module in modules:
            for param in module.parameters():
                if id(param) not in seen_ids:
                    seen_ids.add(id(param))
                    parameters.append(param)
        return parameters
# Exemplo n.º 9 ("Example No. 9")
# 0
class IncrementalModelRecurrentImplicitFactorizationResnet(
        AbstractIncrementalModel):
    """Incremental recurrent policy with an implicit-factorisation text encoder.

    Each step combines four components in
    ``IncrementalMultimodalRecurrentSimpleModule``:

    * the latest image, encoded by a ResNet image module;
    * an image recurrence module;
    * the instruction, encoded by ``TextImplicitFactorizationModule``;
    * the previous-action embedding.

    The pointer-model variant is not implemented.
    """

    # Settings passed to TextImplicitFactorizationModule. The text part of the
    # final module's input size is num_factors * factors_embedding_size
    # (2 * 250 in the original hard-coded form).
    _NUM_FACTORS = 2
    _FACTORS_VOCABULARY_SIZE = 60
    _FACTORS_EMBEDDING_SIZE = 250

    def __init__(self, config, constants):
        AbstractIncrementalModel.__init__(self, config, constants)
        # Substituted for a missing previous action in get_probs(_batch).
        self.none_action = config["num_actions"]
        self.image_module = ImageResnetModule(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        self.image_recurrence_module = IncrementalRecurrenceSimpleModule(
            input_emb_dim=constants["image_emb_dim"],
            output_emb_dim=constants["image_emb_dim"])
        if config["use_pointer_model"]:
            # A TextPointerModule variant was sketched but never implemented.
            raise AssertionError("Not implemented")
        self.text_module = TextImplicitFactorizationModule(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=constants["lstm_emb_dim"],
            vocab_size=config["vocab_size"],
            num_factors=self._NUM_FACTORS,
            factors_vocabulary_size=self._FACTORS_VOCABULARY_SIZE,
            factors_embedding_size=self._FACTORS_EMBEDDING_SIZE)
        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])
        # image embedding + factorised text embedding + action embedding.
        total_emb_size = (constants["image_emb_dim"]
                          + self._NUM_FACTORS * self._FACTORS_EMBEDDING_SIZE
                          + constants["action_emb_dim"])
        final_module = IncrementalMultimodalRecurrentSimpleModule(
            image_module=self.image_module,
            image_recurrence_module=self.image_recurrence_module,
            text_module=self.text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])
        self.final_module = final_module
        if torch.cuda.is_available():
            self.image_module.cuda()
            self.image_recurrence_module.cuda()
            self.text_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()

    @staticmethod
    def _load_state_file(path):
        """``torch.load`` ``path``; without CUDA, remap every storage to CPU."""
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda s_, l_: s_)

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Score a batch of observed states from scratch (no model state).

        The states are sorted by instruction length, longest first, before
        batching. The rows of the result follow that sorted order, not the
        order of ``agent_observed_state_list``.

        :return: the action probabilities, i.e. the first element of
            ``self.final_module``'s output.
        """
        for aos in agent_observed_state_list:
            assert isinstance(aos, AgentObservedState)

        agent_observed_state_list = sorted(
            agent_observed_state_list,
            key=lambda aos_: len(aos_.get_instruction()),
            reverse=True)

        image_seq_lens = [
            aos.get_num_images() for aos in agent_observed_state_list
        ]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        max_len = max(image_seq_lens)
        image_seqs = [
            aos.get_image()[:max_len] for aos in agent_observed_state_list
        ]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float())

        instructions = [
            aos.get_instruction() for aos in agent_observed_state_list
        ]
        read_pointers = [
            aos.get_read_pointers() for aos in agent_observed_state_list
        ]
        instructions_batch = (instructions, read_pointers)

        prev_actions = [
            self.none_action if a is None else a
            for a in (aos.get_previous_action()
                      for aos in agent_observed_state_list)
        ]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        # get_probs shows final_module returns (probs, new_state, image_emb_seq).
        # Unpacking into two names raised ValueError, so take only the probs.
        outputs = self.final_module(image_batch,
                                    image_seq_lens_batch,
                                    instructions_batch,
                                    prev_actions_batch,
                                    mode,
                                    model_state=None)
        return outputs[0]

    def get_probs(self, agent_observed_state, model_state, mode=None):
        """Run one incremental step for a single observed state.

        :param agent_observed_state: the AgentObservedState to score. Only its
            last image is fed in.
        :param model_state: recurrent state, passed straight to
            ``self.final_module``.
        :param mode: forwarded unchanged to ``self.final_module``.
        :return: ``(probs_batch, new_model_state, image_emb_seq)`` as produced
            by ``self.final_module``.
        """
        assert isinstance(agent_observed_state, AgentObservedState)
        agent_observed_state_list = [agent_observed_state]

        image_seq_lens = [1]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        image_seqs = [[aos.get_last_image()]
                      for aos in agent_observed_state_list]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float())

        instructions = [
            aos.get_instruction() for aos in agent_observed_state_list
        ]
        read_pointers = [
            aos.get_read_pointers() for aos in agent_observed_state_list
        ]
        instructions_batch = (instructions, read_pointers)

        prev_actions_raw = [
            aos.get_previous_action() for aos in agent_observed_state_list
        ]
        prev_actions = [
            self.none_action if a is None else a for a in prev_actions_raw
        ]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        probs_batch, new_model_state, image_emb_seq = self.final_module(
            image_batch, image_seq_lens_batch, instructions_batch,
            prev_actions_batch, mode, model_state)
        return probs_batch, new_model_state, image_emb_seq

    def get_recent_factorization_entropy(self):
        """Return ``self.text_module.mean_factory_entropy``.

        Its meaning is defined by TextImplicitFactorizationModule and is not
        visible here (TODO confirm).
        """
        return self.text_module.mean_factory_entropy

    def load_saved_model(self, load_dir):
        """Restore all sub-modules from the files written by ``save_model``."""
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._load_state_file(image_module_path))
        image_recurrence_module_path = os.path.join(
            load_dir, "image_recurrence_module_state.bin")
        self.image_recurrence_module.load_state_dict(
            self._load_state_file(image_recurrence_module_path))
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(self._load_state_file(text_module_path))
        action_module_path = os.path.join(load_dir, "action_module_state.bin")
        self.action_module.load_state_dict(self._load_state_file(action_module_path))
        final_module_path = os.path.join(load_dir, "final_module_state.bin")
        self.final_module.load_state_dict(self._load_state_file(final_module_path))

    def save_model(self, save_dir):
        """Write each sub-module's ``state_dict`` under ``save_dir``."""
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        image_module_path = os.path.join(save_dir, "image_module_state.bin")
        torch.save(self.image_module.state_dict(), image_module_path)
        image_recurrence_module_path = os.path.join(
            save_dir, "image_recurrence_module_state.bin")
        torch.save(self.image_recurrence_module.state_dict(),
                   image_recurrence_module_path)
        text_module_path = os.path.join(save_dir, "text_module_state.bin")
        torch.save(self.text_module.state_dict(), text_module_path)
        action_module_path = os.path.join(save_dir, "action_module_state.bin")
        torch.save(self.action_module.state_dict(), action_module_path)
        final_module_path = os.path.join(save_dir, "final_module_state.bin")
        torch.save(self.final_module.state_dict(), final_module_path)

    def get_parameters(self):
        """Return all trainable parameters, each listed once.

        The final module is built from the other sub-modules. If it registers
        them as children, their parameters would otherwise be listed twice, and
        PyTorch optimizers warn about duplicate parameters in a group. The
        first occurrence of each parameter is kept.
        """
        parameters = []
        seen_ids = set()
        for module in (self.image_module,
                       self.image_recurrence_module,
                       self.text_module,
                       self.action_module,
                       self.final_module):
            for param in module.parameters():
                if id(param) not in seen_ids:
                    seen_ids.add(id(param))
                    parameters.append(param)
        return parameters