class ModelPolicyNetworkResnetWithStop(AbstractModel):
    """Policy network built from image, text and previous-action modules.

    An ``ImageResnetModule``, a text module (``TextPointerModule`` when
    ``config["use_pointer_model"]`` is truthy, otherwise ``TextSimpleModule``)
    and an ``ActionSimpleModule`` are combined by a
    ``MultimodalSimpleWithStopModule``. All four modules are moved to the GPU
    when CUDA is available.
    """

    def __init__(self, config, constants):
        """Build the sub-modules from ``config`` and ``constants``."""
        AbstractModel.__init__(self, config, constants)

        # Action id substituted when an observed state has no previous action.
        self.none_action = config["num_actions"]

        self.image_module = ImageResnetModule(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=3 * constants["max_num_images"],
            image_height=config["image_height"],
            image_width=config["image_width"])

        # Both text modules take the same keyword arguments.
        text_module_cls = (TextPointerModule if config["use_pointer_model"]
                           else TextSimpleModule)
        self.text_module = text_module_cls(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=constants["lstm_emb_dim"],
            vocab_size=config["vocab_size"])

        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])

        total_emb_size = sum(
            constants[key]
            for key in ("image_emb_dim", "lstm_emb_dim", "action_emb_dim"))
        self.final_module = MultimodalSimpleWithStopModule(
            image_module=self.image_module,
            text_module=self.text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])

        if torch.cuda.is_available():
            for module in (self.image_module, self.text_module,
                           self.action_module, self.final_module):
                module.cuda()

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Run the final module on a batch of observed states.

        The states are first sorted by decreasing instruction length, and
        every input tensor/list is built in that sorted order.
        NOTE(review): the output presumably follows the sorted order rather
        than the caller's order - confirm against the final module.

        Args:
            agent_observed_state_list: iterable of ``AgentObservedState``.
            mode: forwarded unchanged to the final module.

        Returns:
            Whatever ``self.final_module`` returns for the batch.
        """
        for state in agent_observed_state_list:
            assert isinstance(state, AgentObservedState)

        ordered_states = sorted(
            agent_observed_state_list,
            key=lambda state_: len(state_.get_instruction()),
            reverse=True)

        image_array = np.array([s.get_image() for s in ordered_states])
        image_batch = cuda_var(torch.from_numpy(image_array).float())

        instructions_batch = (
            [s.get_instruction() for s in ordered_states],
            [s.get_read_pointers() for s in ordered_states],
        )

        prev_actions = []
        for s in ordered_states:
            action = s.get_previous_action()
            prev_actions.append(self.none_action if action is None else action)
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        return self.final_module(
            image_batch, instructions_batch, prev_actions_batch, mode)

    def load_saved_model(self, load_dir):
        """Load the four module state dicts stored under ``load_dir``."""
        if torch.cuda.is_available():
            read_state = torch.load
        else:
            def read_state(path):
                # Without CUDA, keep every storage where it was deserialized.
                return torch.load(
                    path, map_location=lambda storage, location: storage)

        checkpoints = (
            (self.image_module, "image_module_state.bin"),
            (self.text_module, "text_module_state.bin"),
            (self.action_module, "action_module_state.bin"),
            (self.final_module, "final_module_state.bin"),
        )
        for module, file_name in checkpoints:
            module.load_state_dict(
                read_state(os.path.join(load_dir, file_name)))

    def save_model(self, save_dir):
        """Write the four module state dicts into ``save_dir``.

        The directory is created when it does not exist yet.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        checkpoints = (
            (self.image_module, "image_module_state.bin"),
            (self.text_module, "text_module_state.bin"),
            (self.action_module, "action_module_state.bin"),
            (self.final_module, "final_module_state.bin"),
        )
        for module, file_name in checkpoints:
            torch.save(module.state_dict(), os.path.join(save_dir, file_name))

    def get_parameters(self):
        """Return the parameters of all four modules as one flat list."""
        modules = (self.image_module, self.text_module,
                   self.action_module, self.final_module)
        return [param for module in modules for param in module.parameters()]
class IncrementalModelEmnlp(AbstractIncrementalModel):
    """Incremental model with separate previous-direction and previous-block embeddings.

    Combines an ``ImageCnnEmnlp`` over a five-image history, a
    ``TextSimpleModule`` and two ``ActionSimpleModule`` embeddings (previous
    action and previous block) through an ``IncrementalMultimodalEmnlp``.

    NOTE(review): a second class with the same name appears further down in
    this source; if both live in the same module the later one shadows this
    one - confirm which definition is intended to be used.
    """

    def __init__(self, config, constants):
        """Build the sub-modules; all dimensions are read from ``config``."""
        AbstractIncrementalModel.__init__(self, config, constants)
        # Action id substituted when there is no previous action.
        self.none_action = config["num_actions"]
        self.config = config
        self.constants = constants

        # CNN over the image history: 3 channels per image, 5 images.
        self.image_module = ImageCnnEmnlp(
            image_emb_size=config["image_emb_dim"],
            input_num_channels=3 * 5,
            image_height=config["image_height"],
            image_width=config["image_width"])

        # emb_dim is the word embedding size, hidden_dim the output size.
        self.text_module = TextSimpleModule(
            emb_dim=config["word_emb_dim"],
            hidden_dim=config["lstm_emb_dim"],
            vocab_size=config["vocab_size"])

        self.previous_action_module = ActionSimpleModule(
            num_actions=config["no_actions"],
            action_emb_size=config["previous_action_embedding_dim"])

        self.previous_block_module = ActionSimpleModule(
            num_actions=config["no_blocks"],
            action_emb_size=config["previous_block_embedding_dim"])

        self.final_module = IncrementalMultimodalEmnlp(
            image_module=self.image_module,
            text_module=self.text_module,
            previous_action_module=self.previous_action_module,
            previous_block_module=self.previous_block_module,
            input_embedding_size=(
                config["lstm_emb_dim"]
                + config["image_emb_dim"]
                + config["previous_action_embedding_dim"]
                + config["previous_block_embedding_dim"]),
            output_hidden_size=config["h1_hidden_dim"],
            blocks_hidden_size=config["no_blocks"],
            directions_hidden_size=config["no_actions"],
            max_episode_length=(constants["horizon"] + 5))

        if torch.cuda.is_available():
            self.image_module.cuda()
            self.text_module.cuda()
            self.previous_action_module.cuda()
            self.previous_block_module.cuda()
            self.final_module.cuda()

    @staticmethod
    def _load_state(path):
        """Read a saved state dict from ``path``.

        Without CUDA, ``map_location`` returns each storage unchanged so the
        tensors stay where they were deserialized instead of being restored
        to the device they were saved from.
        """
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, location: storage)

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Batched inference is not supported by this model."""
        raise NotImplementedError()

    def get_probs(self, agent_observed_state, model_state, mode=None,
                  volatile=False):
        """Run one incremental step for a single observed state.

        Args:
            agent_observed_state: an ``AgentObservedState``.
            model_state: recurrent state forwarded to the final module.
            mode: accepted for interface compatibility; not used here.
            volatile: forwarded to ``cuda_var`` for the image tensor.

        Returns:
            ``(probs_batch, new_model_state, None, None)``; the last two
            slots are placeholders.
        """
        assert isinstance(agent_observed_state, AgentObservedState)

        # Last five images of the history.
        # Presumably already zero-padded when fewer exist - TODO confirm.
        images = agent_observed_state.get_image()[-5:]
        image_batch = cuda_var(
            torch.from_numpy(np.array(images)).float(), volatile)

        # Batch size is 1; 15 = 3 channels x 5 images, matching
        # input_num_channels above. Height/width come from the same config
        # values the image module was built with, instead of a hard-coded
        # 128 x 128 that breaks for any other configured image size.
        image_batch = image_batch.view(
            1, 3 * 5, self.config["image_height"], self.config["image_width"])

        # A list holding the single instruction, plus False in the second slot.
        instructions_batch = ([agent_observed_state.get_instruction()], False)

        prev_action = agent_observed_state.get_previous_action()
        if prev_action is None:
            prev_action = self.none_action

        # NOTE(review): 81 is hard-coded; presumably it equals
        # self.none_action - confirm before replacing it.
        if prev_action == 81:
            previous_direction_id = [4]
        else:
            previous_direction_id = [prev_action % 4]
        previous_block_id = [int(prev_action / 4)]

        prev_block_id_batch = cuda_var(
            torch.from_numpy(np.array(previous_block_id)))
        prev_direction_id_batch = cuda_var(
            torch.from_numpy(np.array(previous_direction_id)))

        probs_batch, new_model_state = self.final_module(
            image_batch, instructions_batch, prev_block_id_batch,
            prev_direction_id_batch, model_state)

        return probs_batch, new_model_state, None, None

    def init_weights(self):
        """Call ``init_weights`` on every sub-module."""
        self.text_module.init_weights()
        self.image_module.init_weights()
        self.previous_action_module.init_weights()
        self.previous_block_module.init_weights()
        self.final_module.init_weights()

    def share_memory(self):
        """Call ``share_memory`` on every sub-module."""
        self.image_module.share_memory()
        self.text_module.share_memory()
        self.previous_action_module.share_memory()
        self.previous_block_module.share_memory()
        self.final_module.share_memory()

    def get_state_dict(self):
        """Return a dict mapping module name to that module's state dict."""
        return {
            "image_module": self.image_module.state_dict(),
            "text_module": self.text_module.state_dict(),
            "previous_action_module":
                self.previous_action_module.state_dict(),
            "previous_block_module": self.previous_block_module.state_dict(),
            "final_module": self.final_module.state_dict(),
        }

    def load_from_state_dict(self, nested_state_dict):
        """Restore every sub-module from a dict made by ``get_state_dict``."""
        self.image_module.load_state_dict(nested_state_dict["image_module"])
        self.text_module.load_state_dict(nested_state_dict["text_module"])
        self.previous_action_module.load_state_dict(
            nested_state_dict["previous_action_module"])
        self.previous_block_module.load_state_dict(
            nested_state_dict["previous_block_module"])
        self.final_module.load_state_dict(nested_state_dict["final_module"])

    def load_resnet_model(self, load_dir):
        """Load only the image module state from ``load_dir``."""
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._load_state(image_module_path))

    def load_lstm_model(self, load_dir):
        """Load only the text module state from ``load_dir``."""
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(self._load_state(text_module_path))

    def load_saved_model(self, load_dir):
        """Load every sub-module state file stored under ``load_dir``."""
        checkpoints = (
            (self.image_module, "image_module_state.bin"),
            (self.previous_action_module, "previous_action_module_state.bin"),
            (self.previous_block_module, "previous_block_module_state.bin"),
            (self.text_module, "text_module_state.bin"),
            (self.final_module, "final_module_state.bin"),
        )
        for module, file_name in checkpoints:
            module.load_state_dict(
                self._load_state(os.path.join(load_dir, file_name)))

    def save_model(self, save_dir):
        """Write every sub-module state file into ``save_dir``.

        The directory is created when it does not exist yet.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        checkpoints = (
            (self.image_module, "image_module_state.bin"),
            (self.previous_action_module, "previous_action_module_state.bin"),
            (self.previous_block_module, "previous_block_module_state.bin"),
            (self.text_module, "text_module_state.bin"),
            (self.final_module, "final_module_state.bin"),
        )
        for module, file_name in checkpoints:
            torch.save(module.state_dict(), os.path.join(save_dir, file_name))

    def get_parameters(self):
        """Return the final module's parameters as a list.

        Only ``final_module`` is queried; the other modules are passed to it
        at construction time.
        """
        return list(self.final_module.parameters())

    def get_named_parameters(self):
        """Return the final module's ``(name, parameter)`` pairs as a list."""
        return list(self.final_module.named_parameters())
class ModelPolicyNetworkSymbolic(AbstractModel):
    """Policy network that consumes symbolic landmark observations.

    The "image" input is a per-state result of
    ``get_visible_landmark_r_theta`` rather than pixels. It is embedded by a
    ``SymbolicImageModule`` built on radius, angle and landmark modules, then
    combined with text and previous-action embeddings in a
    ``MultimodalSimpleModule``.
    """

    def __init__(self, config, constants):
        """Build the sub-modules from ``config`` and ``constants``."""
        AbstractModel.__init__(self, config, constants)
        # Action id substituted when an observed state has no previous action.
        self.none_action = config["num_actions"]

        landmark_names = get_all_landmark_names()
        self.radius_module = RadiusModule(15)
        self.angle_module = AngleModule(48)
        self.landmark_module = LandmarkModule(63)
        self.image_module = SymbolicImageModule(
            landmark_names=landmark_names,
            radius_module=self.radius_module,
            angle_module=self.angle_module,
            landmark_module=self.landmark_module)

        if config["use_pointer_model"]:
            self.text_module = TextPointerModule(
                emb_dim=constants["word_emb_dim"],
                hidden_dim=constants["lstm_emb_dim"],
                vocab_size=config["vocab_size"])
        else:
            self.text_module = TextSimpleModule(
                emb_dim=constants["word_emb_dim"],
                hidden_dim=constants["lstm_emb_dim"],
                vocab_size=config["vocab_size"])

        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])

        # NOTE(review): 32 * 3 * 63 is the hard-coded symbolic image
        # embedding size; presumably 63 matches LandmarkModule(63) - confirm.
        total_emb_size = (32 * 3 * 63
                          + constants["lstm_emb_dim"]
                          + constants["action_emb_dim"])
        self.final_module = MultimodalSimpleModule(
            image_module=self.image_module,
            text_module=self.text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])

        if torch.cuda.is_available():
            for module, _ in self._checkpoint_modules():
                module.cuda()

    def _checkpoint_modules(self):
        """Return ``(module, checkpoint file name)`` pairs in save/load order."""
        return (
            (self.image_module, "image_module_state.bin"),
            (self.text_module, "text_module_state.bin"),
            (self.action_module, "action_module_state.bin"),
            (self.final_module, "final_module_state.bin"),
            (self.radius_module, "radius_module_state.bin"),
            (self.angle_module, "angle_module_state.bin"),
            (self.landmark_module, "landmark_module_state.bin"),
        )

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Run the final module on a batch of observed states.

        The states are sorted by decreasing instruction length before any
        input is built, so all inputs are in that sorted order.
        NOTE(review): the output presumably follows the sorted order rather
        than the caller's order - confirm against the final module.

        Args:
            agent_observed_state_list: iterable of ``AgentObservedState``.
            mode: forwarded unchanged to the final module.

        Returns:
            Whatever ``self.final_module`` returns for the batch.
        """
        for aos in agent_observed_state_list:
            assert isinstance(aos, AgentObservedState)

        agent_observed_state_list = sorted(
            agent_observed_state_list,
            key=lambda aos_: len(aos_.get_instruction()),
            reverse=True)

        # One symbolic observation per state, from the agent pose and the
        # landmark positions.
        image_batch = []
        for aos in agent_observed_state_list:
            x_pos, z_pos, y_angle = aos.get_position_orientation()
            landmark_pos_dict = aos.get_landmark_pos_dict()
            image_batch.append(get_visible_landmark_r_theta(
                x_pos, z_pos, y_angle, landmark_pos_dict))

        instructions = [aos.get_instruction()
                        for aos in agent_observed_state_list]
        read_pointers = [aos.get_read_pointers()
                         for aos in agent_observed_state_list]
        instructions_batch = (instructions, read_pointers)

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        return self.final_module(
            image_batch, instructions_batch, prev_actions_batch, mode)

    def load_saved_model(self, load_dir):
        """Load all seven module state files stored under ``load_dir``."""
        if torch.cuda.is_available():
            torch_load = torch.load
        else:
            def torch_load(path):
                # Without CUDA, keep every storage where it was deserialized.
                return torch.load(
                    path, map_location=lambda storage, location: storage)

        # Each module is loaded exactly once (the final module used to be
        # loaded twice from the same file).
        for module, file_name in self._checkpoint_modules():
            module.load_state_dict(
                torch_load(os.path.join(load_dir, file_name)))

    def save_model(self, save_dir):
        """Write all seven module state files into ``save_dir``.

        The radius, angle and landmark module states are written too, because
        ``load_saved_model`` reads those files and would otherwise fail on a
        directory produced by this method. The directory is created when it
        does not exist yet.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for module, file_name in self._checkpoint_modules():
            torch.save(module.state_dict(), os.path.join(save_dir, file_name))

    def get_parameters(self):
        """Return image, text, action and final module parameters as a list."""
        parameters = list(self.image_module.parameters())
        parameters += list(self.text_module.parameters())
        parameters += list(self.action_module.parameters())
        parameters += list(self.final_module.parameters())
        return parameters
class IncrementalModelRecurrentPolicyNetworkResnet(AbstractIncrementalModel):
    """Recurrent policy network with optional auxiliary prediction heads.

    Core modules (always built): ResNet image module, image recurrence
    module, text module, action embedding and the final multimodal module.
    Auxiliary modules (action prediction, temporal autoencoder, object
    detection, symbolic language prediction, goal prediction) are built only
    when the matching ``config["do_*"]`` flag is truthy; otherwise the
    attribute is ``None``.
    """

    # (attribute name, key in the nested state dict, checkpoint file name).
    # The order is the order in which modules are moved to the GPU, shared,
    # saved, loaded and listed.
    _CORE_MODULES = (
        ("image_module", "image_module", "image_module_state.bin"),
        ("image_recurrence_module", "image_recurrence_module",
         "image_recurrence_module_state.bin"),
        ("text_module", "text_module", "text_module_state.bin"),
        ("action_module", "action_module", "action_module_state.bin"),
        ("final_module", "final_module", "final_module_state.bin"),
    )
    _AUXILIARY_MODULES = (
        ("action_prediction_module", "ap_module",
         "auxiliary_action_prediction.bin"),
        ("temporal_autoencoder_module", "tae_module",
         "auxiliary_temporal_autoencoder.bin"),
        ("object_detection_module", "od_module",
         "auxiliary_object_detection.bin"),
        ("symbolic_language_prediction_module", "sym_lang_module",
         "auxiliary_symbolic_language_prediction.bin"),
        ("goal_prediction_module", "goal_pred_module",
         "auxiliary_goal_prediction.bin"),
    )

    def __init__(self, config, constants):
        """Build core and (optionally) auxiliary modules."""
        AbstractIncrementalModel.__init__(self, config, constants)
        # Action id substituted when there is no previous action.
        self.none_action = config["num_actions"]

        self.image_module = ImageResnetModule(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        self.num_cameras = 1
        self.image_recurrence_module = IncrementalRecurrenceSimpleModule(
            input_emb_dim=(constants["image_emb_dim"] * self.num_cameras
                           + constants["action_emb_dim"]),
            output_emb_dim=constants["image_emb_dim"])

        if config["use_pointer_model"]:
            self.text_module = TextPointerModule(
                emb_dim=constants["word_emb_dim"],
                hidden_dim=constants["lstm_emb_dim"],
                vocab_size=config["vocab_size"])
        else:
            self.text_module = TextBiLSTMModule(
                emb_dim=constants["word_emb_dim"],
                hidden_dim=constants["lstm_emb_dim"],
                vocab_size=config["vocab_size"])

        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])

        if config["use_pointer_model"]:
            total_emb_size = (constants["image_emb_dim"]
                              + 4 * constants["lstm_emb_dim"]
                              + constants["action_emb_dim"])
        else:
            total_emb_size = (
                (self.num_cameras + 1) * constants["image_emb_dim"]
                + 2 * constants["lstm_emb_dim"]
                + constants["action_emb_dim"])

        if config["do_action_prediction"]:
            self.action_prediction_module = ActionPredictionModule(
                2 * self.num_cameras * constants["image_emb_dim"],
                constants["image_emb_dim"],
                config["num_actions"])
        else:
            self.action_prediction_module = None

        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_module = TemporalAutoencoderModule(
                self.action_module,
                self.num_cameras * constants["image_emb_dim"],
                constants["action_emb_dim"],
                constants["image_emb_dim"])
        else:
            self.temporal_autoencoder_module = None

        if config["do_object_detection"]:
            self.landmark_names = get_all_landmark_names()
            self.object_detection_module = ObjectDetectionModule(
                image_module=self.image_module,
                image_emb_size=self.num_cameras * constants["image_emb_dim"],
                num_objects=67)
        else:
            self.object_detection_module = None

        if config["do_symbolic_language_prediction"]:
            self.symbolic_language_prediction_module = \
                SymbolicLanguagePredictionModule(
                    total_emb_size=2 * constants["lstm_emb_dim"])
        else:
            self.symbolic_language_prediction_module = None

        if config["do_goal_prediction"]:
            self.goal_prediction_module = GoalPredictionModule(
                total_emb_size=32)
        else:
            self.goal_prediction_module = None

        self.final_module = \
            TmpIncrementalMultimodalDenseValtsRecurrentSimpleModule(
                image_module=self.image_module,
                image_recurrence_module=self.image_recurrence_module,
                text_module=self.text_module,
                action_module=self.action_module,
                total_emb_size=total_emb_size,
                num_actions=config["num_actions"])

        if torch.cuda.is_available():
            for module, _, _ in self._active_modules():
                module.cuda()

    def _active_modules(self):
        """Return ``(module, state key, file name)`` for every built module.

        Core modules come first, then the auxiliary modules that are not
        ``None``.
        """
        active = []
        for attr, key, file_name in (self._CORE_MODULES
                                     + self._AUXILIARY_MODULES):
            module = getattr(self, attr)
            if module is not None:
                active.append((module, key, file_name))
        return active

    @staticmethod
    def _load_state(path):
        """Read a saved state dict from ``path``.

        Without CUDA, ``map_location`` returns each storage unchanged so the
        tensors stay where they were deserialized.
        """
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, location: storage)

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Disabled: always raises ``AssertionError("Buggy")``.

        The batched implementation was switched off by an unconditional
        raise; the unreachable code after it has been removed. Use
        ``get_probs`` for single-state inference.
        """
        raise AssertionError("Buggy")

    def get_probs(self, agent_observed_state, model_state, mode=None,
                  volatile=False):
        """Run one incremental step for a single observed state.

        Only the most recent image of the state is fed to the model, as a
        sequence of length one.

        Args:
            agent_observed_state: an ``AgentObservedState``.
            model_state: recurrent state forwarded to the final module.
            mode: forwarded unchanged to the final module.
            volatile: forwarded to ``cuda_var`` for the image and
                previous-action tensors.

        Returns:
            ``(probs_batch, new_model_state, image_emb_seq, state_feature)``
            exactly as produced by the final module.
        """
        assert isinstance(agent_observed_state, AgentObservedState)
        agent_observed_state_list = [agent_observed_state]

        image_seq_lens = [1]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))

        image_seqs = [[aos.get_last_image()]
                      for aos in agent_observed_state_list]
        image_batch = cuda_var(
            torch.from_numpy(np.array(image_seqs)).float(), volatile)

        instructions = [aos.get_instruction()
                        for aos in agent_observed_state_list]
        read_pointers = [aos.get_read_pointers()
                         for aos in agent_observed_state_list]
        instructions_batch = (instructions, read_pointers)

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(
            torch.from_numpy(np.array(prev_actions)), volatile)

        probs_batch, new_model_state, image_emb_seq, state_feature = \
            self.final_module(image_batch, image_seq_lens_batch,
                              instructions_batch, prev_actions_batch,
                              mode, model_state)
        return probs_batch, new_model_state, image_emb_seq, state_feature

    def action_prediction_log_prob(self, batch_input):
        """Apply the action prediction module to ``batch_input``."""
        assert self.action_prediction_module is not None, "Action prediction module not created. Check config."
        return self.action_prediction_module(batch_input)

    def predict_action_result(self, batch_image_feature, action_batch):
        """Apply the temporal autoencoder to image features and actions."""
        assert self.temporal_autoencoder_module is not None, "Temporal action module not created. Check config."
        return self.temporal_autoencoder_module(
            batch_image_feature, action_batch)

    def predict_goal_result(self, batch_state_feature):
        """Apply the goal prediction module to ``batch_state_feature``."""
        assert self.goal_prediction_module is not None, "Goal Prediction module not created. Check config."
        return self.goal_prediction_module(batch_state_feature)

    def get_probs_and_visible_objects(self, agent_observed_state_list,
                                      batch_image_feature):
        """Run object detection and collect per-state visible landmarks.

        Returns:
            ``(landmark_log_prob, distance_log_prob, theta_log_prob,
            landmarks_visible)``. The first three come from the object
            detection module applied to ``batch_image_feature``; the last is
            a list with one entry per observed state, produced by the
            module's ``get_visible_landmark_r_theta``.
        """
        assert self.object_detection_module is not None, "Object detection module not created. Check config."

        landmarks_visible = []
        for aos in agent_observed_state_list:
            x_pos, z_pos, y_angle = aos.get_position_orientation()
            landmark_pos_dict = aos.get_landmark_pos_dict()
            landmarks_visible.append(
                self.object_detection_module.get_visible_landmark_r_theta(
                    x_pos, z_pos, y_angle, landmark_pos_dict))

        # Per the original author: shape is BATCH_SIZE x num objects x 2.
        landmark_log_prob, distance_log_prob, theta_log_prob = \
            self.object_detection_module(batch_image_feature)

        return (landmark_log_prob, distance_log_prob, theta_log_prob,
                landmarks_visible)

    def get_language_prediction_probs(self, batch_input):
        """Apply the symbolic language prediction module to ``batch_input``."""
        assert self.symbolic_language_prediction_module is not None, \
            "Language prediction module not created. Check config."
        return self.symbolic_language_prediction_module(batch_input)

    def init_weights(self):
        """Initialise text, image recurrence and image module weights."""
        self.text_module.init_weights()
        self.image_recurrence_module.init_weights()
        self.image_module.init_weights()

    def share_memory(self):
        """Call ``share_memory`` on every built module."""
        for module, _, _ in self._active_modules():
            module.share_memory()

    def get_state_dict(self):
        """Return a dict mapping state key to each built module's state dict."""
        nested_state_dict = dict()
        for module, key, _ in self._active_modules():
            nested_state_dict[key] = module.state_dict()
        return nested_state_dict

    def load_from_state_dict(self, nested_state_dict):
        """Restore every built module from a dict made by ``get_state_dict``."""
        for module, key, _ in self._active_modules():
            module.load_state_dict(nested_state_dict[key])

    def load_resnet_model(self, load_dir):
        """Load only the image module state from ``load_dir``."""
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._load_state(image_module_path))

    def fix_resnet(self):
        """Delegate to the image module's ``fix_resnet``."""
        self.image_module.fix_resnet()

    def load_lstm_model(self, load_dir):
        """Load only the text module state from ``load_dir``."""
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(self._load_state(text_module_path))

    def load_saved_model(self, load_dir):
        """Load the state file of every built module from ``load_dir``."""
        for module, _, file_name in self._active_modules():
            module.load_state_dict(
                self._load_state(os.path.join(load_dir, file_name)))

    def save_model(self, save_dir):
        """Write the state file of every built module into ``save_dir``.

        The directory is created when it does not exist yet.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        for module, _, file_name in self._active_modules():
            torch.save(module.state_dict(), os.path.join(save_dir, file_name))

    def get_parameters(self):
        """Return the parameters of every built module as one flat list."""
        parameters = []
        for module, _, _ in self._active_modules():
            parameters += list(module.parameters())
        return parameters

    def get_named_parameters(self):
        """Return ``(name, parameter)`` pairs of every built module as a list."""
        named_parameters = []
        for module, _, _ in self._active_modules():
            named_parameters += list(module.named_parameters())
        return named_parameters
class IncrementalModelEmnlp(AbstractIncrementalModel):
    """Incremental policy model combining an image CNN, an LSTM text encoder
    and a previous-action embedding through ``IncrementalMultimodalEmnlp``.

    The model is queried one observation at a time via ``get_probs``; batched
    inference (``get_probs_batch``) is not implemented.
    """

    # The image encoder consumes a fixed-length history of images stacked along
    # the channel axis (3 channels per image, 5 images in history). The
    # encoder's input width, the history slice and the reshape in get_probs are
    # all derived from these two values so they cannot drift apart.
    CHANNELS_PER_IMAGE = 3
    NUM_HISTORY_IMAGES = 5

    def __init__(self, config, constants):
        """Build all sub-modules and move them to the GPU when one is available.

        Args:
            config: mapping read for "num_actions", "image_height",
                "image_width", "vocab_size", "h1_hidden_dim",
                "blocks_hidden_dim" and "action_hidden_dim".
            constants: mapping read for "image_emb_dim", "word_emb_dim",
                "lstm_emb_dim", "action_emb_dim" and "horizon".
        """
        AbstractIncrementalModel.__init__(self, config, constants)
        # Extra action id (== num_actions) used when there is no previous action.
        self.none_action = config["num_actions"]

        self.config = config
        self.constants = constants

        # CNN over images - using what is essentially SimpleImage currently
        self.image_module = ImageCnnEmnlp(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=(self.CHANNELS_PER_IMAGE
                                * self.NUM_HISTORY_IMAGES),
            image_height=config["image_height"],
            image_width=config["image_width"])

        # LSTM to embed text
        self.text_module = TextSimpleModule(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=constants["lstm_emb_dim"],
            vocab_size=config["vocab_size"])

        # Action module to embed previous action+block
        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])

        # Put it all together
        self.final_module = IncrementalMultimodalEmnlp(
            image_module=self.image_module,
            text_module=self.text_module,
            action_module=self.action_module,
            input_embedding_size=(constants["lstm_emb_dim"]
                                  + constants["image_emb_dim"]
                                  + constants["action_emb_dim"]),
            output_hidden_size=config["h1_hidden_dim"],
            blocks_hidden_size=config["blocks_hidden_dim"],
            directions_hidden_size=config["action_hidden_dim"],
            # NOTE(review): the meaning of the "+ 5" slack is not visible
            # here - confirm before tying it to NUM_HISTORY_IMAGES.
            max_episode_length=(constants["horizon"] + 5))

        if torch.cuda.is_available():
            self.image_module.cuda()
            self.text_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()

    @staticmethod
    def _load_state(path):
        """Read a saved state dict from ``path``.

        When CUDA is unavailable, storages are kept on the CPU through
        ``map_location`` so a checkpoint written on a GPU machine still loads.
        """
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, _location: storage)

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Batched inference is not supported by this model.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def get_probs(self, agent_observed_state, model_state, mode=None,
                  volatile=False):
        """Run the model on a single observation.

        Args:
            agent_observed_state: the ``AgentObservedState`` to act on.
            model_state: recurrent state forwarded to the final module.
            mode: accepted for interface compatibility; not used here.
            volatile: forwarded to ``cuda_var`` for the image batch.

        Returns:
            ``(probs_batch, new_model_state, None, None)``. The two trailing
            ``None`` values are placeholders the original author noted are
            not needed by this model.
        """
        assert isinstance(agent_observed_state, AgentObservedState)

        # Image list is already padded with zero-images if fewer than
        # NUM_HISTORY_IMAGES images are available (per the original author).
        images = agent_observed_state.get_image()[-self.NUM_HISTORY_IMAGES:]
        image_batch = cuda_var(
            torch.from_numpy(np.array(images)).float(), volatile)

        # Stack the image history along the channel axis; batch size is 1.
        image_batch = image_batch.view(
            1,
            self.CHANNELS_PER_IMAGE * self.NUM_HISTORY_IMAGES,
            self.config["image_height"],
            self.config["image_width"])

        # List of instructions. False is there because the text module expects
        # a second tuple element. TODO: figure out what this is
        instructions_batch = ([agent_observed_state.get_instruction()], False)

        # A missing previous action is encoded with the extra none_action id.
        prev_actions_raw = [agent_observed_state.get_previous_action()]
        prev_actions = [
            self.none_action if a is None else a for a in prev_actions_raw
        ]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        probs_batch, new_model_state = self.final_module(
            image_batch, instructions_batch, prev_actions_batch, model_state)

        return probs_batch, new_model_state, None, None

    def init_weights(self):
        """Delegate weight initialisation to every sub-module."""
        self.text_module.init_weights()
        self.image_module.init_weights()
        self.action_module.init_weights()
        self.final_module.init_weights()

    def share_memory(self):
        """Call ``share_memory()`` on every sub-module."""
        self.image_module.share_memory()
        self.text_module.share_memory()
        self.action_module.share_memory()
        self.final_module.share_memory()

    def get_state_dict(self):
        """Return a dict mapping each sub-module name to its state dict."""
        nested_state_dict = dict()
        nested_state_dict["image_module"] = self.image_module.state_dict()
        nested_state_dict["text_module"] = self.text_module.state_dict()
        nested_state_dict["action_module"] = self.action_module.state_dict()
        nested_state_dict["final_module"] = self.final_module.state_dict()
        return nested_state_dict

    def load_from_state_dict(self, nested_state_dict):
        """Restore every sub-module from a dict produced by ``get_state_dict``."""
        self.image_module.load_state_dict(nested_state_dict["image_module"])
        self.text_module.load_state_dict(nested_state_dict["text_module"])
        self.action_module.load_state_dict(nested_state_dict["action_module"])
        self.final_module.load_state_dict(nested_state_dict["final_module"])

    def load_resnet_model(self, load_dir):
        """Load only the image module from ``load_dir/image_module_state.bin``."""
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._load_state(image_module_path))

    def load_lstm_model(self, load_dir):
        """Load only the text module from ``load_dir/text_module_state.bin``."""
        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(self._load_state(text_module_path))

    def load_saved_model(self, load_dir):
        """Load the image, action, text and final modules from ``load_dir``.

        Expects the files written by ``save_model``.
        """
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._load_state(image_module_path))

        action_module_path = os.path.join(load_dir, "action_module_state.bin")
        self.action_module.load_state_dict(self._load_state(action_module_path))

        text_module_path = os.path.join(load_dir, "text_module_state.bin")
        self.text_module.load_state_dict(self._load_state(text_module_path))

        final_module_path = os.path.join(load_dir, "final_module_state.bin")
        self.final_module.load_state_dict(self._load_state(final_module_path))

    def save_model(self, save_dir):
        """Write one state file per sub-module into ``save_dir``.

        The directory is created when it does not exist.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # save state file for image nn
        image_module_path = os.path.join(save_dir, "image_module_state.bin")
        torch.save(self.image_module.state_dict(), image_module_path)

        # save state file for action emb
        action_module_path = os.path.join(save_dir, "action_module_state.bin")
        torch.save(self.action_module.state_dict(), action_module_path)

        # save state file for text nn
        text_module_path = os.path.join(save_dir, "text_module_state.bin")
        torch.save(self.text_module.state_dict(), text_module_path)

        # save state file for final nn
        final_module_path = os.path.join(save_dir, "final_module_state.bin")
        torch.save(self.final_module.state_dict(), final_module_path)

    def get_parameters(self):
        """Return the parameters of ``final_module`` as a list.

        Only ``final_module`` is queried; it was constructed with the image,
        text and action modules.
        NOTE(review): presumably it registers them as sub-modules, so their
        parameters are already included - confirm.
        """
        parameters = list(self.final_module.parameters())
        return parameters

    def get_named_parameters(self):
        """Return the ``(name, parameter)`` pairs of ``final_module`` as a list."""
        named_parameters = list(self.final_module.named_parameters())
        return named_parameters
class ModelPolicyNetworkSymbolicText(AbstractModel):
    """Batched policy network over image sequences, symbolic instructions and
    the previous action, combined by ``MultimodalRecurrentSimpleModule``.
    """

    def __init__(self, config, constants):
        """Create the sub-modules and move them to the GPU when available.

        Args:
            config: mapping read for "num_actions", "image_height" and
                "image_width".
            constants: mapping read for "image_emb_dim" and "action_emb_dim".
        """
        AbstractModel.__init__(self, config, constants)
        # Extra action id (== num_actions) used when there is no previous action.
        self.none_action = config["num_actions"]
        landmark_names = get_all_landmark_names()  # result is not used below

        # Embeddings for the three parts of a symbolic instruction.
        self.radius_module = RadiusModule(15)
        self.angle_module = AngleModule(48)
        self.landmark_module = LandmarkModule(63)

        image_emb_dim = constants["image_emb_dim"]
        self.image_module = ImageResnetModule(
            image_emb_size=image_emb_dim,
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        self.image_recurrence_module = RecurrenceSimpleModule(
            input_emb_dim=image_emb_dim,
            output_emb_dim=image_emb_dim)

        self.text_module = SymbolicInstructionModule(
            radius_embedding=self.radius_module,
            theta_embedding=self.angle_module,
            landmark_embedding=self.landmark_module)

        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])

        # Image embedding + symbolic text embedding (32 * 4) + action embedding.
        total_emb_size = image_emb_dim + 32 * 4 + constants["action_emb_dim"]
        self.final_module = MultimodalRecurrentSimpleModule(
            image_module=self.image_module,
            image_recurrence_module=self.image_recurrence_module,
            text_module=self.text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])

        if torch.cuda.is_available():
            for module in (self.image_module, self.text_module,
                           self.action_module, self.final_module,
                           self.radius_module, self.angle_module,
                           self.landmark_module):
                module.cuda()

    def _checkpoint_entries(self):
        """Return ``(file name, module)`` pairs in save/load order."""
        return (
            ("image_module_state.bin", self.image_module),
            ("text_module_state.bin", self.text_module),
            ("action_module_state.bin", self.action_module),
            ("final_module_state.bin", self.final_module),
            ("radius_module_state.bin", self.radius_module),
            ("angle_module_state.bin", self.angle_module),
            ("landmark_module_state.bin", self.landmark_module),
        )

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Run the final module on a batch of observed states.

        The states are processed in order of decreasing instruction length.

        Args:
            agent_observed_state_list: iterable of ``AgentObservedState``.
            mode: forwarded unchanged to the final module.

        Returns:
            Whatever ``final_module`` returns for the batch.
        """
        for state in agent_observed_state_list:
            assert isinstance(state, AgentObservedState)

        # Longest instruction first.
        ordered_states = list(agent_observed_state_list)
        ordered_states.sort(
            key=lambda state_: len(state_.get_instruction()), reverse=True)

        seq_lengths = [state.get_num_images() for state in ordered_states]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(seq_lengths)))

        # Trim every image sequence to the longest length in the batch.
        longest = max(seq_lengths)
        image_sequences = [
            state.get_image()[:longest] for state in ordered_states
        ]
        image_batch = cuda_var(
            torch.from_numpy(np.array(image_sequences)).float())

        instructions_batch = [
            state.get_symbolic_instruction() for state in ordered_states
        ]

        # A missing previous action is encoded with the extra none_action id.
        previous_actions = [
            state.get_previous_action() for state in ordered_states
        ]
        encoded_actions = [
            self.none_action if action is None else action
            for action in previous_actions
        ]
        prev_actions_batch = cuda_var(
            torch.from_numpy(np.array(encoded_actions)))

        return self.final_module(image_batch, image_seq_lens_batch,
                                 instructions_batch, prev_actions_batch, mode)

    def load_saved_model(self, load_dir):
        """Restore every checkpointed module from its file in ``load_dir``."""
        if torch.cuda.is_available():
            read_state = torch.load
        else:
            # Without CUDA, keep the storages on the CPU while loading.
            def read_state(path):
                return torch.load(
                    path, map_location=lambda storage, _location: storage)

        for file_name, module in self._checkpoint_entries():
            state_path = os.path.join(load_dir, file_name)
            module.load_state_dict(read_state(state_path))

    def save_model(self, save_dir):
        """Write one state file per checkpointed module into ``save_dir``.

        The directory is created when it does not exist.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for file_name, module in self._checkpoint_entries():
            state_path = os.path.join(save_dir, file_name)
            torch.save(module.state_dict(), state_path)

    def get_parameters(self):
        """Return the concatenated parameter lists of the checkpointed modules."""
        parameters = []
        for _, module in self._checkpoint_entries():
            parameters.extend(module.parameters())
        return parameters
class IncrementalModelRecurrentPolicyNetworkSymbolicTextWithLSTMResnet(
        AbstractIncrementalModel):
    """Incremental recurrent policy network that mixes a symbolic instruction
    embedding with an LSTM embedding of the raw instruction, plus optional
    auxiliary heads (action prediction, temporal autoencoding, object
    detection) enabled through ``config``.
    """

    def __init__(self, config, constants):
        """Create the sub-modules and move them to the GPU when available.

        Args:
            config: mapping read for "num_actions", "image_height",
                "image_width", "vocab_size", "do_action_prediction",
                "do_temporal_autoencoding" and "do_object_detection".
            constants: mapping read for "image_emb_dim", "word_emb_dim",
                "lstm_emb_dim" and "action_emb_dim".
        """
        AbstractIncrementalModel.__init__(self, config, constants)
        # Extra action id (== num_actions) used when there is no previous action.
        self.none_action = config["num_actions"]
        # NOTE(review): the returned names are unused; the call is kept in case
        # get_all_landmark_names() has side effects - confirm and drop.
        landmark_names = get_all_landmark_names()
        self.radius_module = RadiusModule(15)
        self.angle_module = AngleModule(12)  # (48)
        self.landmark_module = LandmarkModule(67)
        self.num_cameras = 1
        self.image_module = ImageRyanResnetModule(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        self.image_recurrence_module = IncrementalRecurrenceSimpleModule(
            input_emb_dim=constants["image_emb_dim"] * self.num_cameras,
            output_emb_dim=constants["image_emb_dim"])
        self.symbolic_text_module = SymbolicInstructionModule(
            radius_embedding=self.radius_module,
            theta_embedding=self.angle_module,
            landmark_embedding=self.landmark_module)
        self.lstm_text_module = TextSimpleModule(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=constants["lstm_emb_dim"],
            vocab_size=config["vocab_size"])
        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])
        # Image + symbolic text (32 * 2) + LSTM text + action embeddings.
        total_emb_size = (self.num_cameras * constants["image_emb_dim"]
                          + 32 * 2
                          + constants["lstm_emb_dim"]
                          + constants["action_emb_dim"])

        # Auxiliary heads are optional; a disabled head is stored as None and
        # every consumer below checks for that.
        if config["do_action_prediction"]:
            self.action_prediction_module = ActionPredictionModule(
                2 * self.num_cameras * constants["image_emb_dim"],
                constants["image_emb_dim"], config["num_actions"])
        else:
            self.action_prediction_module = None

        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_module = TemporalAutoencoderModule(
                self.action_module,
                self.num_cameras * constants["image_emb_dim"],
                constants["action_emb_dim"], constants["image_emb_dim"])
        else:
            self.temporal_autoencoder_module = None

        if config["do_object_detection"]:
            self.landmark_names = get_all_landmark_names()
            self.object_detection_module = ObjectDetectionModule(
                image_module=self.image_module,
                image_emb_size=self.num_cameras * constants["image_emb_dim"],
                num_objects=67)
        else:
            self.object_detection_module = None

        self.final_module = IncrementalMultimodalMixedTextRecurrentSimpleModule(
            image_module=self.image_module,
            image_recurrence_module=self.image_recurrence_module,
            symbolic_text_module=self.symbolic_text_module,
            lstm_text_module=self.lstm_text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])

        if torch.cuda.is_available():
            self.image_module.cuda()
            self.image_recurrence_module.cuda()
            self.symbolic_text_module.cuda()
            self.lstm_text_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()
            if self.action_prediction_module is not None:
                self.action_prediction_module.cuda()
            if self.temporal_autoencoder_module is not None:
                self.temporal_autoencoder_module.cuda()
            if self.object_detection_module is not None:
                self.object_detection_module.cuda()

    @staticmethod
    def _load_state(path):
        """Read a saved state dict from ``path``.

        When CUDA is unavailable, storages are kept on the CPU through
        ``map_location`` so a checkpoint written on a GPU machine still loads.
        """
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, _location: storage)

    def _checkpoint_entries(self):
        """Return ``(file name, module)`` pairs in save/load order.

        Auxiliary modules that are disabled (``None``) are left out.
        """
        entries = [
            ("image_module_state.bin", self.image_module),
            ("image_recurrence_module_state.bin", self.image_recurrence_module),
            ("symbolic_text_module_state.bin", self.symbolic_text_module),
            ("lstm_text_module_state.bin", self.lstm_text_module),
            ("action_module_state.bin", self.action_module),
            ("final_module_state.bin", self.final_module),
        ]
        auxiliary = [
            ("auxiliary_action_prediction.bin", self.action_prediction_module),
            ("auxiliary_temporal_autoencoder.bin",
             self.temporal_autoencoder_module),
            ("auxiliary_object_detection.bin", self.object_detection_module),
        ]
        entries.extend(
            (file_name, module) for file_name, module in auxiliary
            if module is not None)
        return entries

    def _parameter_modules(self):
        """Return the modules exposed by the parameter getters.

        ``image_module`` is left out, as in the original implementation where
        its line was commented out.
        NOTE(review): its parameters may still be reachable through
        ``final_module`` - confirm whether the exclusion has any effect.
        """
        return [module for _, module in self._checkpoint_entries()
                if module is not self.image_module]

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Batched inference is not supported by this incremental model.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def _single_step_probs(self, agent_observed_state, symbolic_text,
                           model_state, mode, volatile):
        """Run the final module on one observation with the given symbolic text.

        Builds a batch of size one from the last image, the raw instruction
        with its read pointers, and the previous action.
        """
        image_seq_lens_batch = cuda_tensor(torch.from_numpy(np.array([1])))

        # A batch holding one sequence holding one image.
        image_seqs = [[agent_observed_state.get_last_image()]]
        image_batch = cuda_var(
            torch.from_numpy(np.array(image_seqs)).float(), volatile)

        symbolic_instructions_batch = [symbolic_text]
        lstm_instructions_batch = (
            [agent_observed_state.get_instruction()],
            [agent_observed_state.get_read_pointers()])

        # A missing previous action is encoded with the extra none_action id.
        prev_action = agent_observed_state.get_previous_action()
        prev_actions = [
            self.none_action if prev_action is None else prev_action
        ]
        prev_actions_batch = cuda_var(
            torch.from_numpy(np.array(prev_actions)), volatile)

        probs_batch, new_model_state, image_emb_seq, state_feature = \
            self.final_module(
                image_batch, image_seq_lens_batch,
                symbolic_instructions_batch, lstm_instructions_batch,
                prev_actions_batch, mode, model_state)
        return probs_batch, new_model_state, image_emb_seq, state_feature

    def get_probs(self, agent_observed_state, model_state, mode=None,
                  volatile=False):
        """Run the model on a single observation.

        Args:
            agent_observed_state: the ``AgentObservedState`` to act on; its
                own symbolic instruction is used.
            model_state: recurrent state forwarded to the final module.
            mode: forwarded unchanged to the final module.
            volatile: forwarded to ``cuda_var`` for the input variables.

        Returns:
            ``(probs_batch, new_model_state, image_emb_seq, state_feature)``
            as produced by the final module.
        """
        assert isinstance(agent_observed_state, AgentObservedState)
        return self._single_step_probs(
            agent_observed_state,
            agent_observed_state.get_symbolic_instruction(),
            model_state, mode, volatile)

    def get_probs_symbolic_text(self, agent_observed_state, symbolic_text,
                                model_state, mode=None, volatile=False):
        """Same as ``get_probs`` but uses ``symbolic_text`` instead of the
        symbolic instruction stored in the observed state.
        """
        assert isinstance(agent_observed_state, AgentObservedState)
        return self._single_step_probs(
            agent_observed_state, symbolic_text, model_state, mode, volatile)

    def action_prediction_log_prob(self, batch_input):
        """Apply the auxiliary action-prediction module to ``batch_input``.

        The module must have been enabled with config["do_action_prediction"].
        """
        assert self.action_prediction_module is not None, "Action prediction module not created. Check config."
        return self.action_prediction_module(batch_input)

    def predict_action_result(self, batch_image_feature, action_batch):
        """Apply the auxiliary temporal-autoencoder module.

        The module must have been enabled with
        config["do_temporal_autoencoding"].
        """
        assert self.temporal_autoencoder_module is not None, "Temporal action module not created. Check config."
        return self.temporal_autoencoder_module(batch_image_feature,
                                                action_batch)

    def get_probs_and_visible_objects(self, agent_observed_state_list,
                                      batch_image_feature):
        """Apply the object-detection module and collect visible landmarks.

        The module must have been enabled with config["do_object_detection"].

        Returns:
            ``(landmark_log_prob, distance_log_prob, theta_log_prob,
            landmarks_visible)`` where ``landmarks_visible`` holds one entry
            per observed state, as returned by
            ``get_visible_landmark_r_theta``.
        """
        assert self.object_detection_module is not None, "Object detection module not created. Check config."
        landmarks_visible = []
        for aos in agent_observed_state_list:
            x_pos, z_pos, y_angle = aos.get_position_orientation()
            landmark_pos_dict = aos.get_landmark_pos_dict()
            visible_landmarks_dict = \
                self.object_detection_module.get_visible_landmark_r_theta(
                    x_pos, z_pos, y_angle, landmark_pos_dict)
            landmarks_visible.append(visible_landmarks_dict)

        # Original author's note: shape is BATCH_SIZE x num objects x 2
        landmark_log_prob, distance_log_prob, theta_log_prob = \
            self.object_detection_module(batch_image_feature)

        return (landmark_log_prob, distance_log_prob, theta_log_prob,
                landmarks_visible)

    def init_weights(self):
        """Initialise the image and image-recurrence modules only."""
        self.image_module.init_weights()
        self.image_recurrence_module.init_weights()

    def load_resnet_model(self, load_dir):
        """Load only the image module from ``load_dir/image_module_state.bin``."""
        image_module_path = os.path.join(load_dir, "image_module_state.bin")
        self.image_module.load_state_dict(self._load_state(image_module_path))

    def load_saved_model(self, load_dir):
        """Restore every checkpointed module from its file in ``load_dir``.

        Files of disabled auxiliary modules are not read.
        """
        for file_name, module in self._checkpoint_entries():
            state_path = os.path.join(load_dir, file_name)
            module.load_state_dict(self._load_state(state_path))

    def save_model(self, save_dir):
        """Write one state file per checkpointed module into ``save_dir``.

        The directory is created when it does not exist. Disabled auxiliary
        modules are not written.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for file_name, module in self._checkpoint_entries():
            state_path = os.path.join(save_dir, file_name)
            torch.save(module.state_dict(), state_path)

    def get_parameters(self):
        """Return the concatenated parameters of ``_parameter_modules()``."""
        parameters = []
        for module in self._parameter_modules():
            parameters += list(module.parameters())
        return parameters

    def get_named_parameters(self):
        """Return the concatenated ``(name, parameter)`` pairs of
        ``_parameter_modules()``.
        """
        named_parameters = []
        for module in self._parameter_modules():
            named_parameters += list(module.named_parameters())
        return named_parameters
class IncrementalModelRecurrentPolicyNetworkGoalImageResnet(AbstractIncrementalModel):
    """Incremental policy model driven by the current image and a goal image.

    The constructor builds an image encoder, an image recurrence module, an
    action embedding module and a final multimodal module. Three auxiliary
    modules (action prediction, temporal autoencoding, object detection) are
    created only when the matching ``config`` flag is truthy; otherwise the
    corresponding attribute is ``None``.
    """

    # Number of objects the object detection module is built for.
    # NOTE(review): presumably the number of landmark names returned by
    # get_all_landmark_names() -- confirm.
    NUM_DETECTED_OBJECTS = 63

    # File that holds the image encoder weights inside a checkpoint directory.
    _IMAGE_MODULE_FILE = "image_module_state.bin"

    def __init__(self, config, constants):
        """Create all sub-modules and move them to the GPU when available.

        Args:
            config: mapping read for "num_actions", "image_height",
                "image_width", "do_action_prediction",
                "do_temporal_autoencoding" and "do_object_detection".
            constants: mapping read for "image_emb_dim" and "action_emb_dim".
        """
        AbstractIncrementalModel.__init__(self, config, constants)
        # Index used in place of a previous action when there is none.
        self.none_action = config["num_actions"]

        image_emb_dim = constants["image_emb_dim"]
        action_emb_dim = constants["action_emb_dim"]
        num_actions = config["num_actions"]

        self.image_module = ImageResnetModule(
            image_emb_size=image_emb_dim,
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        self.image_recurrence_module = IncrementalRecurrenceSimpleModule(
            input_emb_dim=image_emb_dim,
            output_emb_dim=image_emb_dim)
        self.action_module = ActionSimpleModule(
            num_actions=num_actions,
            action_emb_size=action_emb_dim)
        total_emb_size = 2 * image_emb_dim + action_emb_dim

        # Optional auxiliary modules; each stays None when disabled.
        self.action_prediction_module = None
        if config["do_action_prediction"]:
            self.action_prediction_module = ActionPredictionModule(
                2 * image_emb_dim, image_emb_dim, num_actions)

        self.temporal_autoencoder_module = None
        if config["do_temporal_autoencoding"]:
            self.temporal_autoencoder_module = TemporalAutoencoderModule(
                self.action_module, image_emb_dim, action_emb_dim,
                image_emb_dim)

        self.object_detection_module = None
        if config["do_object_detection"]:
            self.landmark_names = get_all_landmark_names()
            self.object_detection_module = ObjectDetectionModule(
                image_module=self.image_module,
                image_emb_size=image_emb_dim,
                num_objects=self.NUM_DETECTED_OBJECTS)

        self.final_module = IncrementalMultimodalRecurrentSimpleGoalImageModule(
            image_module=self.image_module,
            image_recurrence_module=self.image_recurrence_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=num_actions)

        if torch.cuda.is_available():
            self.image_module.cuda()
            self.image_recurrence_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()
            if self.action_prediction_module is not None:
                self.action_prediction_module.cuda()
            if self.temporal_autoencoder_module is not None:
                self.temporal_autoencoder_module.cuda()
            if self.object_detection_module is not None:
                self.object_detection_module.cuda()

    @staticmethod
    def _read_state(path):
        """Load a saved state dict, remapping storages when CUDA is absent."""
        if torch.cuda.is_available():
            return torch.load(path)
        # Keep every storage where it was deserialized instead of moving it
        # to the device it was saved from.
        return torch.load(path, map_location=lambda storage, location: storage)

    @staticmethod
    def _require_module(module, message):
        """Return ``module``, raising AssertionError(message) when it is None.

        An explicit raise is used instead of ``assert`` so the check is not
        removed when Python runs with ``-O``.
        """
        if module is None:
            raise AssertionError(message)
        return module

    def _checkpoint_entries(self):
        """Return (module, file name) pairs for every module that exists.

        The order is the one used for loading, saving and parameter listing;
        disabled auxiliary modules (``None``) are left out.
        """
        candidates = [
            (self.image_module, self._IMAGE_MODULE_FILE),
            (self.image_recurrence_module,
             "image_recurrence_module_state.bin"),
            (self.action_module, "action_module_state.bin"),
            (self.final_module, "final_module_state.bin"),
            (self.action_prediction_module,
             "auxiliary_action_prediction.bin"),
            (self.temporal_autoencoder_module,
             "auxiliary_temporal_autoencoder.bin"),
            (self.object_detection_module, "auxiliary_object_detection.bin"),
        ]
        return [(module, file_name) for module, file_name in candidates
                if module is not None]

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Batched inference is not supported by this model.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def get_probs(self, agent_observed_state, model_state, mode=None):
        """Run the final module on a single observed state.

        Args:
            agent_observed_state: AgentObservedState providing the last
                image, the goal image and the previous action.
            model_state: forwarded unchanged to the final module.
            mode: forwarded unchanged to the final module.

        Returns:
            The tuple ``(probs_batch, new_model_state, image_emb_seq)``
            produced by the final module.
        """
        assert isinstance(agent_observed_state, AgentObservedState)

        # A batch of one sequence that contains exactly one image.
        image_seq_lens_batch = cuda_tensor(torch.from_numpy(np.array([1])))

        last_image = [[agent_observed_state.get_last_image()]]
        image_batch = cuda_var(torch.from_numpy(np.array(last_image)).float())

        goal_image = [[agent_observed_state.get_goal_image()]]
        goal_image_batch = cuda_var(
            torch.from_numpy(np.array(goal_image)).float())

        previous_action = agent_observed_state.get_previous_action()
        if previous_action is None:
            previous_action = self.none_action
        prev_actions_batch = cuda_var(
            torch.from_numpy(np.array([previous_action])))

        probs_batch, new_model_state, image_emb_seq = self.final_module(
            image_batch, image_seq_lens_batch, goal_image_batch,
            prev_actions_batch, mode, model_state)
        return probs_batch, new_model_state, image_emb_seq

    def action_prediction_log_prob(self, batch_input):
        """Apply the action prediction module to ``batch_input``.

        Raises:
            AssertionError: if the module was not created by the constructor.
        """
        module = self._require_module(
            self.action_prediction_module,
            "Action prediction module not created. Check config.")
        return module(batch_input)

    def predict_action_result(self, batch_image_feature, action_batch):
        """Apply the temporal autoencoder module to the given batches.

        Raises:
            AssertionError: if the module was not created by the constructor.
        """
        module = self._require_module(
            self.temporal_autoencoder_module,
            "Temporal action module not created. Check config.")
        return module(batch_image_feature, action_batch)

    def get_probs_and_visible_objects(self, agent_observed_state_list,
                                      batch_image_feature):
        """Return detection output and the visible landmarks per state.

        Args:
            agent_observed_state_list: states queried for their position,
                orientation and landmark position dictionary.
            batch_image_feature: input passed to the object detection module.

        Returns:
            ``(probs_batch, landmarks_visible)`` where ``landmarks_visible``
            has one entry per state in ``agent_observed_state_list``.

        Raises:
            AssertionError: if the module was not created by the constructor.
        """
        detector = self._require_module(
            self.object_detection_module,
            "Object detection module not created. Check config.")

        landmarks_visible = []
        for aos in agent_observed_state_list:
            x_pos, z_pos, y_angle = aos.get_position_orientation()
            landmarks_visible.append(
                detector.get_visible_landmark_r_theta(
                    x_pos, z_pos, y_angle, aos.get_landmark_pos_dict(),
                    self.landmark_names))

        # shape is BATCH_SIZE x 63 x 2
        probs_batch = detector(batch_image_feature)

        # landmarks_visible is list of length BATCH_SIZE, each item is a set
        # containing landmark indices
        return probs_batch, landmarks_visible

    def load_resnet_model(self, load_dir):
        """Load only the image module weights from ``load_dir``."""
        image_module_path = os.path.join(load_dir, self._IMAGE_MODULE_FILE)
        self.image_module.load_state_dict(self._read_state(image_module_path))

    def load_saved_model(self, load_dir):
        """Load the weights of every existing module from ``load_dir``."""
        for module, file_name in self._checkpoint_entries():
            module.load_state_dict(
                self._read_state(os.path.join(load_dir, file_name)))

    def save_model(self, save_dir):
        """Save every existing module's state dict, creating ``save_dir``."""
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        for module, file_name in self._checkpoint_entries():
            torch.save(module.state_dict(),
                       os.path.join(save_dir, file_name))

    def get_parameters(self):
        """Return the concatenated parameter lists of all existing modules.

        NOTE(review): several modules are constructed from other modules
        (for example the final module receives the image module), so the
        list may contain the same parameter more than once -- confirm.
        """
        parameters = []
        for module, _ in self._checkpoint_entries():
            parameters += list(module.parameters())
        return parameters
class IncrementalModelRecurrentImplicitFactorizationResnet(
        AbstractIncrementalModel):
    """Incremental recurrent policy model with an implicit-factorization text module.

    The constructor builds an image encoder, an image recurrence module, a
    text module, an action embedding module and a final multimodal module.
    The pointer text model is not supported: the constructor raises when
    ``config["use_pointer_model"]`` is truthy.
    """

    # Settings handed to the text module. The text share of the final
    # embedding size is NUM_FACTORS * FACTORS_EMBEDDING_SIZE (2 * 250).
    # NOTE(review): presumably the text module emits one embedding of
    # FACTORS_EMBEDDING_SIZE per factor -- confirm.
    NUM_FACTORS = 2
    FACTORS_VOCABULARY_SIZE = 60
    FACTORS_EMBEDDING_SIZE = 250

    def __init__(self, config, constants):
        """Create all sub-modules and move them to the GPU when available.

        Args:
            config: mapping read for "num_actions", "image_height",
                "image_width", "use_pointer_model" and "vocab_size".
            constants: mapping read for "image_emb_dim", "word_emb_dim",
                "lstm_emb_dim" and "action_emb_dim".

        Raises:
            AssertionError: if ``config["use_pointer_model"]`` is truthy.
        """
        AbstractIncrementalModel.__init__(self, config, constants)
        # Index used in place of a previous action when there is none.
        self.none_action = config["num_actions"]

        self.image_module = ImageResnetModule(
            image_emb_size=constants["image_emb_dim"],
            input_num_channels=3,
            image_height=config["image_height"],
            image_width=config["image_width"],
            using_recurrence=True)
        self.image_recurrence_module = IncrementalRecurrenceSimpleModule(
            input_emb_dim=constants["image_emb_dim"],
            output_emb_dim=constants["image_emb_dim"])

        if config["use_pointer_model"]:
            raise AssertionError("Not implemented")
        self.text_module = TextImplicitFactorizationModule(
            emb_dim=constants["word_emb_dim"],
            hidden_dim=constants["lstm_emb_dim"],
            vocab_size=config["vocab_size"],
            num_factors=self.NUM_FACTORS,
            factors_vocabulary_size=self.FACTORS_VOCABULARY_SIZE,
            factors_embedding_size=self.FACTORS_EMBEDDING_SIZE)

        self.action_module = ActionSimpleModule(
            num_actions=config["num_actions"],
            action_emb_size=constants["action_emb_dim"])

        # Only the non-pointer configuration can reach this point, so a
        # single formula is enough.
        total_emb_size = (constants["image_emb_dim"]
                          + self.NUM_FACTORS * self.FACTORS_EMBEDDING_SIZE
                          + constants["action_emb_dim"])

        self.final_module = IncrementalMultimodalRecurrentSimpleModule(
            image_module=self.image_module,
            image_recurrence_module=self.image_recurrence_module,
            text_module=self.text_module,
            action_module=self.action_module,
            total_emb_size=total_emb_size,
            num_actions=config["num_actions"])

        if torch.cuda.is_available():
            self.image_module.cuda()
            self.image_recurrence_module.cuda()
            self.text_module.cuda()
            self.action_module.cuda()
            self.final_module.cuda()

    @staticmethod
    def _read_state(path):
        """Load a saved state dict, remapping storages when CUDA is absent."""
        if torch.cuda.is_available():
            return torch.load(path)
        # Keep every storage where it was deserialized instead of moving it
        # to the device it was saved from.
        return torch.load(path, map_location=lambda storage, location: storage)

    def _checkpoint_entries(self):
        """Return (module, file name) pairs in load/save/listing order."""
        return [
            (self.image_module, "image_module_state.bin"),
            (self.image_recurrence_module,
             "image_recurrence_module_state.bin"),
            (self.text_module, "text_module_state.bin"),
            (self.action_module, "action_module_state.bin"),
            (self.final_module, "final_module_state.bin"),
        ]

    def _text_and_action_batches(self, agent_observed_state_list):
        """Build the instruction and previous-action inputs for a batch.

        Returns:
            ``(instructions_batch, prev_actions_batch)`` where
            ``instructions_batch`` is the pair (instructions, read pointers)
            and a missing previous action is replaced by ``self.none_action``.
        """
        instructions = [
            aos.get_instruction() for aos in agent_observed_state_list
        ]
        read_pointers = [
            aos.get_read_pointers() for aos in agent_observed_state_list
        ]
        prev_actions = []
        for aos in agent_observed_state_list:
            action = aos.get_previous_action()
            prev_actions.append(self.none_action if action is None else action)
        prev_actions_batch = cuda_var(
            torch.from_numpy(np.array(prev_actions)))
        return (instructions, read_pointers), prev_actions_batch

    def get_probs_batch(self, agent_observed_state_list, mode=None):
        """Run the final module on a batch of observed states.

        The states are sorted by decreasing instruction length before the
        batch is built and no inverse permutation is applied, so the rows of
        the result follow that sorted order.

        Args:
            agent_observed_state_list: AgentObservedState instances.
            mode: forwarded unchanged to the final module.

        Returns:
            The first element of the final module's output.
        """
        for aos in agent_observed_state_list:
            assert isinstance(aos, AgentObservedState)

        # sort list by instruction length
        ordered_states = sorted(
            agent_observed_state_list,
            key=lambda aos_: len(aos_.get_instruction()),
            reverse=True)

        image_seq_lens = [aos.get_num_images() for aos in ordered_states]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        max_len = max(image_seq_lens)
        image_seqs = [aos.get_image()[:max_len] for aos in ordered_states]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float())

        instructions_batch, prev_actions_batch = \
            self._text_and_action_batches(ordered_states)

        outputs = self.final_module(image_batch, image_seq_lens_batch,
                                    instructions_batch, prev_actions_batch,
                                    mode, model_state=None)
        # get_probs unpacks three values from this same module; indexing
        # keeps this method working whether two or three are returned.
        return outputs[0]

    def get_probs(self, agent_observed_state, model_state, mode=None):
        """Run the final module on a single observed state.

        Args:
            agent_observed_state: AgentObservedState providing the last
                image, the instruction, read pointers and previous action.
            model_state: forwarded unchanged to the final module.
            mode: forwarded unchanged to the final module.

        Returns:
            The tuple ``(probs_batch, new_model_state, image_emb_seq)``
            produced by the final module.
        """
        assert isinstance(agent_observed_state, AgentObservedState)
        single_state_list = [agent_observed_state]

        # A batch of one sequence that contains exactly one image.
        image_seq_lens_batch = cuda_tensor(torch.from_numpy(np.array([1])))
        image_seqs = [[aos.get_last_image()] for aos in single_state_list]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float())

        instructions_batch, prev_actions_batch = \
            self._text_and_action_batches(single_state_list)

        probs_batch, new_model_state, image_emb_seq = self.final_module(
            image_batch, image_seq_lens_batch, instructions_batch,
            prev_actions_batch, mode, model_state)
        return probs_batch, new_model_state, image_emb_seq

    def get_recent_factorization_entropy(self):
        """Return the text module's ``mean_factory_entropy`` attribute."""
        return self.text_module.mean_factory_entropy

    def load_saved_model(self, load_dir):
        """Load the weights of every module from ``load_dir``."""
        for module, file_name in self._checkpoint_entries():
            module.load_state_dict(
                self._read_state(os.path.join(load_dir, file_name)))

    def save_model(self, save_dir):
        """Save every module's state dict, creating ``save_dir`` if needed."""
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        for module, file_name in self._checkpoint_entries():
            torch.save(module.state_dict(),
                       os.path.join(save_dir, file_name))

    def get_parameters(self):
        """Return the concatenated parameter lists of all modules."""
        parameters = []
        for module, _ in self._checkpoint_entries():
            parameters += list(module.parameters())
        return parameters