import os

import numpy as np
import torch

# Project-local dependencies (AbstractIncrementalModel, AgentObservedState,
# ImageCnnEmnlp, TextSimpleModule, ActionSimpleModule,
# IncrementalMultimodalEmnlp, cuda_var) are assumed to be importable from the
# surrounding repository; their exact import paths are not shown here.


class IncrementalModelEmnlp(AbstractIncrementalModel):

    def __init__(self, config, constants):
        AbstractIncrementalModel.__init__(self, config, constants)
        self.none_action = config["num_actions"]
        self.config = config
        self.constants = constants

        # CNN over images - using SimpleImage for testing for now.
        self.image_module = ImageCnnEmnlp(
            image_emb_size=config["image_emb_dim"],
            input_num_channels=3 * 5,  # 3 channels per image, 5 images in history
            image_height=config["image_height"],
            image_width=config["image_width"])

        # This is somewhat counterintuitive: emb_dim is the word embedding size,
        # hidden_dim is the output size.
        self.text_module = TextSimpleModule(
            emb_dim=config["word_emb_dim"],
            hidden_dim=config["lstm_emb_dim"],
            vocab_size=config["vocab_size"])

        self.previous_action_module = ActionSimpleModule(
            num_actions=config["no_actions"],
            action_emb_size=config["previous_action_embedding_dim"])

        self.previous_block_module = ActionSimpleModule(
            num_actions=config["no_blocks"],
            action_emb_size=config["previous_block_embedding_dim"])

        self.final_module = IncrementalMultimodalEmnlp(
            image_module=self.image_module,
            text_module=self.text_module,
            previous_action_module=self.previous_action_module,
            previous_block_module=self.previous_block_module,
            input_embedding_size=config["lstm_emb_dim"] + config["image_emb_dim"]
                                 + config["previous_action_embedding_dim"]
                                 + config["previous_block_embedding_dim"],
            output_hidden_size=config["h1_hidden_dim"],
            blocks_hidden_size=config["no_blocks"],
            directions_hidden_size=config["no_actions"],
            max_episode_length=(constants["horizon"] + 5))

        if torch.cuda.is_available():
            self.image_module.cuda()
            self.text_module.cuda()
            self.previous_action_module.cuda()
            self.previous_block_module.cuda()
            self.final_module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        raise NotImplementedError()

    def get_probs(self, agent_observed_state, model_state, mode=None,
                  volatile=False):
        assert isinstance(agent_observed_state, AgentObservedState)

        # The image history is supposedly already padded with zero-images
        # upstream, so the last five frames always form a full history
        # (worth double-checking that code).
        images = agent_observed_state.get_image()[-5:]
        image_batch = cuda_var(
            torch.from_numpy(np.array(images)).float(), volatile)
        # Flatten the frame history into the channel dimension.
        # TODO: maybe don't hardcode this later on? Batch size is 1.
        image_batch = image_batch.view(1, 15, 128, 128)

        # List of lists: a single instruction as a batch of size 1. The False
        # flag is a second argument expected by the text module.
        instructions_batch = ([agent_observed_state.get_instruction()], False)

        prev_actions_raw = agent_observed_state.get_previous_action()
        prev_actions_raw = (self.none_action if prev_actions_raw is None
                            else prev_actions_raw)

        if prev_actions_raw == 81:
            previous_direction_id = [4]
        else:
            previous_direction_id = [prev_actions_raw % 4]
        # The raw action ranges over the 81-action space; derive the block id
        # from it.
        previous_block_id = [int(prev_actions_raw / 4)]

        prev_block_id_batch = cuda_var(
            torch.from_numpy(np.array(previous_block_id)))
        prev_direction_id_batch = cuda_var(
            torch.from_numpy(np.array(previous_direction_id)))

        probs_batch, new_model_state = self.final_module(
            image_batch, instructions_batch, prev_block_id_batch,
            prev_direction_id_batch, model_state)

        # The last two return values are not used by this model.
        return probs_batch, new_model_state, None, None

    def init_weights(self):
        self.text_module.init_weights()
        self.image_module.init_weights()
        self.previous_action_module.init_weights()
        self.previous_block_module.init_weights()
        self.final_module.init_weights()

    def share_memory(self):
        self.image_module.share_memory()
        self.text_module.share_memory()
        self.previous_action_module.share_memory()
        self.previous_block_module.share_memory()
        self.final_module.share_memory()

    def get_state_dict(self):
        nested_state_dict = dict()
        nested_state_dict["image_module"] = self.image_module.state_dict()
        nested_state_dict["text_module"] = self.text_module.state_dict()
        nested_state_dict["previous_action_module"] = \
            self.previous_action_module.state_dict()
        nested_state_dict["previous_block_module"] = \
            self.previous_block_module.state_dict()
        nested_state_dict["final_module"] = self.final_module.state_dict()
        return nested_state_dict

    def load_from_state_dict(self, nested_state_dict):
        self.image_module.load_state_dict(nested_state_dict["image_module"])
        self.text_module.load_state_dict(nested_state_dict["text_module"])
        self.previous_action_module.load_state_dict(
            nested_state_dict["previous_action_module"])
        self.previous_block_module.load_state_dict(
            nested_state_dict["previous_block_module"])
        self.final_module.load_state_dict(nested_state_dict["final_module"])

    def load_resnet_model(self, load_dir):
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            torch_load = lambda f_: torch.load(
                f_, map_location=lambda s_, l_: s_)
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(torch_load(image_module_path))

    def load_lstm_model(self, load_dir):
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            torch_load = lambda f_: torch.load(
                f_, map_location=lambda s_, l_: s_)
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(torch_load(text_module_path))

    def load_saved_model(self, load_dir):
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            torch_load = lambda f_: torch.load(
                f_, map_location=lambda s_, l_: s_)
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(torch_load(image_module_path))
        previous_action_module_path = os.path.join(
            load_dir, "previous_action_module_state.bin")
        self.previous_action_module.load_state_dict(
            torch_load(previous_action_module_path))
        previous_block_module_path = os.path.join(
            load_dir, "previous_block_module_state.bin")
        self.previous_block_module.load_state_dict(
            torch_load(previous_block_module_path))
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(torch_load(text_module_path))
        final_module_path = os.path.join(load_dir, "final_module_state.bin")
        self.final_module.load_state_dict(torch_load(final_module_path))

    def save_model(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # Save state file for the image CNN.
        image_module_path = os.path.join(save_dir, "image_module_state.bin")
        torch.save(self.image_module.state_dict(), image_module_path)
        # Save state files for the previous-action and previous-block embeddings.
        previous_action_module_path = os.path.join(
            save_dir, "previous_action_module_state.bin")
        torch.save(self.previous_action_module.state_dict(),
                   previous_action_module_path)
        previous_block_module_path = os.path.join(
            save_dir, "previous_block_module_state.bin")
        torch.save(self.previous_block_module.state_dict(),
                   previous_block_module_path)
        # Save state file for the text LSTM.
        text_module_path = os.path.join(save_dir, "text_module_state.bin")
        torch.save(self.text_module.state_dict(), text_module_path)
        # Save state file for the final multimodal module.
        final_module_path = os.path.join(save_dir, "final_module_state.bin")
        torch.save(self.final_module.state_dict(), final_module_path)

    def get_parameters(self):
        # Only the final module's parameters are exposed; it is constructed
        # with the submodules above.
        parameters = list(self.final_module.parameters())
        return parameters

    def get_named_parameters(self):
        named_parameters = list(self.final_module.named_parameters())
        return named_parameters
# NOTE: this second definition of IncrementalModelEmnlp shadows the one above
# when the module is imported. It uses a single joint action embedding instead
# of separate block and direction embeddings.
class IncrementalModelEmnlp(AbstractIncrementalModel):

    def __init__(self, config, constants):
        AbstractIncrementalModel.__init__(self, config, constants)
        self.none_action = config["num_actions"]
        self.config = config
        self.constants = constants

        # CNN over images - using what is essentially SimpleImage currently.
        self.image_module = ImageCnnEmnlp(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=3 * 5,  # 3 channels per image, 5 images in history
            image_height=config["image_height"],
            image_width=config["image_width"])

        # LSTM to embed the instruction text.
        self.text_module = TextSimpleModule(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=constants["lstm_emb_dim"],
            vocab_size=config["vocab_size"])

        # Action module to embed the previous action (block and direction jointly).
        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])

        # Put it all together.
        self.final_module = IncrementalMultimodalEmnlp(
            image_module=self.image_module,
            text_module=self.text_module,
            action_module=self.action_module,
            input_embedding_size=constants["lstm_emb_dim"]
                                 + constants["image_emb_dim"]
                                 + constants["action_emb_dim"],
            output_hidden_size=config["h1_hidden_dim"],
            blocks_hidden_size=config["blocks_hidden_dim"],
            directions_hidden_size=config["action_hidden_dim"],
            max_episode_length=(constants["horizon"] + 5))

        if torch.cuda.is_available():
            self.image_module.cuda()
            self.text_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        raise NotImplementedError()

    def get_probs(self, agent_observed_state, model_state, mode=None,
                  volatile=False):
        assert isinstance(agent_observed_state, AgentObservedState)

        # The image list is already padded with zero-images if fewer than
        # five frames are available.
        images = agent_observed_state.get_image()[-5:]
        image_batch = cuda_var(
            torch.from_numpy(np.array(images)).float(), volatile)
        # Flatten the frame history into the channel dimension.
        # TODO: maybe don't hardcode this later on? Batch size is 1.
        image_batch = image_batch.view(1, 15, self.config["image_height"],
                                       self.config["image_width"])

        # List of instructions. The False flag is a second argument expected by
        # the text module. TODO: figure out what this is.
        instructions_batch = ([agent_observed_state.get_instruction()], False)

        # Previous action. If there is no previous action, encode it as the
        # none/stop action.
        prev_actions_raw = [agent_observed_state.get_previous_action()]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        # Get probabilities.
        probs_batch, new_model_state = self.final_module(
            image_batch, instructions_batch, prev_actions_batch, model_state)

        # The last two return values are not used by this model.
        return probs_batch, new_model_state, None, None

    def init_weights(self):
        self.text_module.init_weights()
        self.image_module.init_weights()
        self.action_module.init_weights()
        self.final_module.init_weights()

    def share_memory(self):
        self.image_module.share_memory()
        self.text_module.share_memory()
        self.action_module.share_memory()
        self.final_module.share_memory()

    def get_state_dict(self):
        nested_state_dict = dict()
        nested_state_dict["image_module"] = self.image_module.state_dict()
        nested_state_dict["text_module"] = self.text_module.state_dict()
        nested_state_dict["action_module"] = self.action_module.state_dict()
        nested_state_dict["final_module"] = self.final_module.state_dict()
        return nested_state_dict

    def load_from_state_dict(self, nested_state_dict):
        self.image_module.load_state_dict(nested_state_dict["image_module"])
        self.text_module.load_state_dict(nested_state_dict["text_module"])
        self.action_module.load_state_dict(nested_state_dict["action_module"])
        self.final_module.load_state_dict(nested_state_dict["final_module"])

    def load_resnet_model(self, load_dir):
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            torch_load = lambda f_: torch.load(
                f_, map_location=lambda s_, l_: s_)
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(torch_load(image_module_path))

    def load_lstm_model(self, load_dir):
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            torch_load = lambda f_: torch.load(
                f_, map_location=lambda s_, l_: s_)
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(torch_load(text_module_path))

    def load_saved_model(self, load_dir):
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            torch_load = lambda f_: torch.load(
                f_, map_location=lambda s_, l_: s_)
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(torch_load(image_module_path))
        action_module_path = os.path.join(load_dir, "action_module_state.bin")
        self.action_module.load_state_dict(torch_load(action_module_path))
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(torch_load(text_module_path))
        final_module_path = os.path.join(load_dir, "final_module_state.bin")
        self.final_module.load_state_dict(torch_load(final_module_path))

    def save_model(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # Save state file for the image CNN.
        image_module_path = os.path.join(save_dir, "image_module_state.bin")
        torch.save(self.image_module.state_dict(), image_module_path)
        # Save state file for the previous-action embedding.
        action_module_path = os.path.join(save_dir, "action_module_state.bin")
        torch.save(self.action_module.state_dict(), action_module_path)
        # Save state file for the text LSTM.
        text_module_path = os.path.join(save_dir, "text_module_state.bin")
        torch.save(self.text_module.state_dict(), text_module_path)
        # Save state file for the final multimodal module.
        final_module_path = os.path.join(save_dir, "final_module_state.bin")
        torch.save(self.final_module.state_dict(), final_module_path)

    def get_parameters(self):
        # Only the final module's parameters are exposed; it is constructed
        # with the submodules above.
        parameters = list(self.final_module.parameters())
        return parameters

    def get_named_parameters(self):
        named_parameters = list(self.final_module.named_parameters())
        return named_parameters
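
# A minimal usage sketch of the model above: construct it from the experiment's
# config/constants dictionaries and step it once per observation, threading the
# recurrent model_state between calls. The function name, the pre-built
# `agent_observed_states` argument, and starting model_state at None are
# illustrative assumptions, not part of the original codebase.
def _example_rollout(config, constants, agent_observed_states):
    model = IncrementalModelEmnlp(config, constants)
    model.init_weights()

    model_state = None  # assumed initial recurrent state
    all_probs = []
    for observed_state in agent_observed_states:
        # get_probs returns (action probabilities, updated recurrent state,
        # None, None); only the first two are used here.
        probs_batch, model_state, _, _ = model.get_probs(
            observed_state, model_state, volatile=True)
        all_probs.append(probs_batch)
    return all_probs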