def __init__(
        self,
        model,
        state=None,
        epoch=0,
        name="",
        run_name="",
):
    """Set up a trainer for a single model.

    Reads the "Training" section of the current parameters, prints the
    model's parameter counts and builds the optimizer named there.

    Args:
        model: Model whose parameters will be optimized.
        state: Previously saved trainer state, forwarded to
            ``self.set_state``.
        epoch: Initial value for both the train and test epoch counters.
        name: Stored as ``self.name``.
        run_name: Stored as ``self.run_name``.

    Raises:
        ValueError: If the configured optimizer is neither "adam" nor "sgd".
    """
    # Vocabulary lookup tables built from the instruction corpus.
    _, _, _, corpus = get_all_instructions()
    self.token2word, self.word2token = get_word_to_token_map(corpus)

    # Training hyper-parameters.
    self.params = get_current_parameters()["Training"]
    self.batch_size = self.params['batch_size']
    self.weight_decay = self.params['weight_decay']
    self.optimizer = self.params['optimizer']
    self.num_loaders = self.params['num_loaders']
    self.lr = self.params['lr']

    self.name = name
    self.dataset_names = None

    n_params = get_n_params(model)
    n_params_tr = get_n_trainable_params(model)
    print("Training Model:")
    print("Number of model parameters: " + str(n_params))
    print("Trainable model parameters: " + str(n_params_tr))

    self.model = model
    self.run_name = run_name

    if self.optimizer == "adam":
        self.optim = optim.Adam(
            self.get_model_parameters(self.model),
            self.lr,
            weight_decay=self.weight_decay)
    elif self.optimizer == "sgd":
        self.optim = optim.SGD(
            self.get_model_parameters(self.model),
            self.lr,
            weight_decay=self.weight_decay,
            momentum=0.9)
    else:
        # Fail here rather than leaving self.optim undefined, which would
        # only surface later as an AttributeError.
        raise ValueError("Unsupported optimizer: " + str(self.optimizer))

    self.train_epoch_num = epoch
    self.train_segment = 0
    self.test_epoch_num = epoch
    self.test_segment = 0
    self.set_state(state)
    self.batch_num = 0
def __init__(self, model_real, model_sim, model_critic,
             model_oracle_critic=None, state=None, epoch=0):
    """Set up a trainer for a real model, a sim model and a critic.

    Reads the "Training" section of the current parameters (and the run
    name from "Setup"), prints parameter counts for all three models,
    optionally shares modules between the real and sim models, and builds
    one optimizer for the two domain models and one for the critic.

    Args:
        model_real: Real-domain model.
        model_sim: Simulation-domain model.
        model_critic: Critic model.
        model_oracle_critic: Optional oracle critic; a message is printed
            when it is provided.
        state: Previously saved trainer state, forwarded to
            ``self.set_state``.
        epoch: Initial value for both the train and test epoch counters.

    Raises:
        ValueError: If the configured optimizer is neither "adam" nor "sgd".
    """
    # Vocabulary lookup tables built from the instruction corpus.
    _, _, _, corpus = get_all_instructions()
    self.token2word, self.word2token = get_word_to_token_map(corpus)

    self.params = get_current_parameters()["Training"]
    self.run_name = get_current_parameters()["Setup"]["run_name"]

    # Required training parameters (KeyError if missing).
    self.batch_size = self.params['batch_size']
    self.iterations_per_epoch = self.params.get("iterations_per_epoch", None)
    self.weight_decay = self.params['weight_decay']
    self.optimizer = self.params['optimizer']
    self.critic_loaders = self.params['critic_loaders']
    self.model_common_loaders = self.params['model_common_loaders']
    self.model_sim_loaders = self.params['model_sim_loaders']
    self.lr = self.params['lr']
    self.critic_steps = self.params['critic_steps']
    self.model_steps = self.params['model_steps']
    self.critic_batch_size = self.params["critic_batch_size"]
    self.model_batch_size = self.params["model_batch_size"]
    self.disable_wloss = self.params["disable_wloss"]

    # Optional training parameters with defaults.
    self.sim_steps_per_real_step = self.params.get(
        "sim_steps_per_real_step", 1)
    self.real_dataset_names = self.params.get("real_dataset_names")
    self.sim_dataset_names = self.params.get("sim_dataset_names")
    self.bidata = self.params.get("bidata", False)
    self.sim_steps_per_common_step = self.params.get(
        "sim_steps_per_common_step", 1)

    n_params_real = get_n_params(model_real)
    n_params_real_tr = get_n_trainable_params(model_real)
    n_params_sim = get_n_params(model_sim)
    n_params_sim_tr = get_n_trainable_params(model_sim)
    n_params_c = get_n_params(model_critic)
    # Count only the trainable critic parameters for the "trainable" line.
    n_params_c_tr = get_n_trainable_params(model_critic)

    print("Training Model:")
    print("Real # model parameters: " + str(n_params_real))
    print("Real # trainable parameters: " + str(n_params_real_tr))
    print("Sim  # model parameters: " + str(n_params_sim))
    print("Sim  # trainable parameters: " + str(n_params_sim_tr))
    print("Critic  # model parameters: " + str(n_params_c))
    print("Critic  # trainable parameters: " + str(n_params_c_tr))

    # Share those modules that are to be shared between real and sim models
    if not self.params.get("disable_domain_weight_sharing"):
        print("Sharing weights between sim and real modules")
        model_real.steal_cross_domain_modules(model_sim)
    else:
        print("NOT Sharing weights between sim and real modules")

    self.model_real = model_real
    self.model_sim = model_sim
    self.model_critic = model_critic
    self.model_oracle_critic = model_oracle_critic
    if self.model_oracle_critic:
        print("Using oracle critic")

    if self.optimizer == "adam":
        Optim = optim.Adam
    elif self.optimizer == "sgd":
        Optim = optim.SGD
    else:
        raise ValueError(f"Unsuppored optimizer {self.optimizer}")

    # One optimizer over the parameters of both domain models, and a
    # separate one for the critic.
    self.optim_models = Optim(
        self.model_real.both_domain_parameters(self.model_sim),
        self.lr,
        weight_decay=self.weight_decay)
    self.optim_critic = Optim(
        self.critic_parameters(),
        self.lr,
        weight_decay=self.weight_decay)

    self.train_epoch_num = epoch
    self.train_segment = 0
    self.test_epoch_num = epoch
    self.test_segment = 0
    self.set_state(state)
def load_model(model_file_override=None, real=False, model_name_override=False):
    """Instantiate the model named in the "Setup" parameters.

    Args:
        model_file_override: If truthy, used instead of
            ``setup["model_file"]`` as the weights file to load.
        real: Not used by this function; kept for interface compatibility.
        model_name_override: If truthy, used instead of ``setup["model"]``
            as the model name.

    Returns:
        A ``(model, model_loaded)`` tuple. ``model`` is ``None`` when the
        model name is not recognized. ``model_loaded`` is True only when a
        PyTorch model was created and weights were loaded from a model file.
    """
    setup = P.get_current_parameters()["Setup"]
    model_name = model_name_override or setup["model"]
    model_file = model_file_override or setup["model_file"] or None
    cuda = setup["cuda"]
    run_name = setup["run_name"]

    # Lazy constructors for all PyTorch models, keyed by model name. Each
    # entry is a lambda so that only the selected model class is touched.
    pytorch_factories = {
        # -------------------------------------------------------------
        # FASTER RSS 2018 Resubmission Model
        # -------------------------------------------------------------
        "gsmn": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jlang": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=False),
        "gsmn_wo_jgnd": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=False,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jclass": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=False, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jgoal": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=False, aux_lang=True),
        "gsmn_w_posnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=True, rot_noise=False),
        "gsmn_w_rotnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=False, rot_noise=True),
        "gsmn_w_bothnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=True, rot_noise=True),
        "gs_fpv": lambda: ModelGSFPV(
            run_name, aux_class_features=True,
            aux_grounding_features=True, aux_lang=True, recurrence=False),
        "gs_fpv_mem": lambda: ModelGSFPV(
            run_name, aux_class_features=True,
            aux_grounding_features=True, aux_lang=True, recurrence=True),
        # -------------------------------------------------------------
        # CoRL 2018 Model
        # -------------------------------------------------------------
        "sm_traj_nav_ratio": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.MODEL_FPV),
        "sm_traj_nav_ratio_path": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.PVN_STAGE1_ONLY),
        "action_gtr": lambda: ModelTrajectoryToAction(run_name),
        # -------------------------------------------------------------
        # CoRL 2018 Refactored
        # -------------------------------------------------------------
        "pvn_full": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.MODEL_FPV),
        "pvn_stage1": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.PVN_STAGE1_ONLY),
        "pvn_stage2": lambda: ModelTrajectoryToAction(run_name),
        # -------------------------------------------------------------
        # CoRL 2018 Top-Down Full Observability Models
        # -------------------------------------------------------------
        "top_down_goal_batched": lambda: ModelTopDownPathGoalPredictorBatched(
            run_name),
        # -------------------------------------------------------------
        # CoRL Baselines
        # -------------------------------------------------------------
        "chaplot": lambda: ModelChaplot(run_name),
        "misra2017": lambda: ModelMisra2017(run_name),
    }

    model = None
    pytorch_model = False

    # -----------------------------------------------------------------
    # Oracles / baselines that ignore images
    # -----------------------------------------------------------------
    if model_name == "oracle":
        rollout_params = get_current_parameters()["Rollout"]
        oracle_type = rollout_params["oracle_type"]
        if oracle_type == "SimpleCarrotPlanner":
            model = SimpleCarrotPlanner()
            print("Using simple carrot planner")
        elif oracle_type == "BasicCarrotPlanner":
            model = BasicCarrotPlanner()
            print("Using basic carrot planner")
        elif oracle_type == "FancyCarrotPlanner":
            model = FancyCarrotPlanner()
            print("Using fancy carrot planner")
        else:
            # Report the value under the key that was actually read; the
            # previous "OracleType" lookup raised KeyError on this path.
            print("UNKNOWN ORACLE: ", oracle_type)
            exit(-1)
    elif model_name == "baseline_straight":
        model = BaselineStraight()
    elif model_name == "baseline_stop":
        model = BaselineStop()
    elif model_name in pytorch_factories:
        model = pytorch_factories[model_name]()
        pytorch_model = True

    model_loaded = False
    if pytorch_model:
        n_params = get_n_params(model)
        n_params_tr = get_n_trainable_params(model)
        print("Loaded PyTorch model!")
        print("Number of model parameters: " + str(n_params))
        print("Trainable model parameters: " + str(n_params_tr))
        model.init_weights()
        model.eval()
        if model_file:
            load_pytorch_model(model, model_file)
            print("Loaded previous model: ", model_file)
            model_loaded = True
        if cuda:
            model = model.cuda()

    return model, model_loaded
def __init__(self, run_name, ignore_lang=False, class_loss=True, ground_loss=True):
    """Build the sub-modules and losses of the top-down path/goal predictor.

    Args:
        run_name: Run identifier; also used as the sub-directory of
            "runs/" for the summary writer.
        ignore_lang: Stored as ``self.ignore_lang``.
        class_loss: If true, create the auxiliary class linear layer.
        ground_loss: If true, create the language filter and the auxiliary
            grounding linear layer, and widen the U-Net input from 32 to
            35 channels.
    """
    super(ModelTopDownPathGoalPredictor, self).__init__()
    self.run_name = run_name
    self.model_name = "top_down_path_pred_pretrain"
    self.writer = SummaryWriter(log_dir="runs/" + run_name)

    self.ignore_lang = ignore_lang
    self.class_loss = class_loss
    self.ground_loss = ground_loss

    # The feature net extracts the 2D feature map from the input image.
    # The label_pool down-sizes the ground-truth labels, which are input at
    # the same size as the input image.
    # The output predicted labels are the size of the feature map.
    self.feature_net = ResNet13Light(32, down_pad=True)
    self.label_pool = nn.MaxPool2d(8)

    if self.ground_loss:
        self.lang_filter = MapLangSemanticFilter(sentence_embedding_size, 32, 3)
        self.aux_ground_linear = nn.Linear(3, 2)
        enable_weight_saving(self.lang_filter, "ground_filter")
        enable_weight_saving(self.aux_ground_linear, "ground_aux_linear")

    if RESNET:
        # NOTE(review): the input channel count is fixed at 35 here
        # regardless of ground_loss, unlike the U-Net branch below --
        # confirm this is intended.
        self.unet = ResNetConditional(sentence_embedding_size, 35, 2)
    else:
        # 32 feature channels, plus 3 more when ground_loss is enabled.
        unet_c_in = 35 if self.ground_loss else 32
        # Hidden sizes are the same whether or not ground_loss is enabled.
        unet_hc1 = 48
        unet_hb1 = 24
        self.unet = Unet5ContextualBneck(
            unet_c_in, 2, sentence_embedding_size,
            hc1=unet_hc1, hb1=unet_hb1, hc2=128,
            split_embedding=splitemb)

    if attention:
        self.sentence_embedding = SentenceEmbeddingSelfAttention(
            word_embedding_size, lstm_size, sentence_embedding_layers,
            attention_heads=attention_heads)
    else:
        self.sentence_embedding = SentenceEmbeddingSimple(
            word_embedding_size, sentence_embedding_size,
            sentence_embedding_layers)

    self.gather2d = Gather2D()

    if self.class_loss:
        self.aux_class_linear = nn.Linear(32, 64)
        enable_weight_saving(self.aux_class_linear, "class_aux_linear")

    print("Sentence Embedding #Params: ", get_n_params(self.sentence_embedding))
    print("U-Net #Params: ", get_n_params(self.unet))
    print("Class auxiliary: ", self.class_loss)
    print("Ground auxiliary: ", self.ground_loss)

    # Enable saving of pre-trained weights
    enable_weight_saving(self.feature_net, "feature_resnet_light")
    enable_weight_saving(self.unet, "unet")
    enable_weight_saving(self.sentence_embedding, "sentence_embedding")

    # Select the mask loss according to the module-level switches.
    if NLL:
        self.mask_loss = nn.NLLLoss2d()
    elif BCE:
        self.mask_loss = nn.BCEWithLogitsLoss()
    elif CE:
        self.spatialsoftmax = SpatialSoftmax2d()
        self.mask_loss = CrossEntropy2d()
    else:
        self.mask_loss = nn.MSELoss()

    # Mean-reduced cross entropy. This is the default behaviour, so the
    # deprecated size_average/reduce flags are not passed.
    self.aux_loss = nn.CrossEntropyLoss()

    self.epoch_numbers = {"train": 0, "eval": 0}
    # Stored as a non-trainable parameter initialized to zero.
    self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

    self.dropout = nn.Dropout(0.5)
    self.dropout2d = nn.Dropout2d(0.5)
    self.dropout3d = nn.Dropout3d(0.5)

    self.viz_images = []
    self.instructions = []
def load_model(model_name_override=False, model_file_override=None, domain="sim"):
    """Instantiate the model named in the "Setup" parameters.

    Args:
        model_name_override: If truthy, used instead of ``setup["model"]``
            as the model name.
        model_file_override: If truthy, used instead of
            ``setup["model_file"]`` as the weights file to load.
        domain: Passed to the bidomain models, either as ``domain`` or as
            ``model_instance_name`` depending on the model class.

    Returns:
        A ``(model, model_loaded)`` tuple. ``model`` is ``None`` when the
        model name is not recognized. ``model_loaded`` is True only when a
        PyTorch model was created and weights were loaded from a model file.
    """
    setup = P.get_current_parameters()["Setup"]
    model_name = model_name_override or setup["model"]
    model_file = model_file_override or setup["model_file"] or None

    # TODO: Move this stuff elsewhere and tidy up the model
    # (both values are read but not used further in this function)
    perception_model_file = setup.get("perception_model_file") or None
    perception_model_real = setup.get("perception_model_real") or None

    cuda = setup["cuda"]
    run_name = setup["run_name"]

    # Lazy constructors for all PyTorch models, keyed by model name. Each
    # entry is a lambda so that only the selected model class is touched.
    pytorch_factories = {
        # -------------------------------------------------------------
        # FASTER RSS 2018 Resubmission Model
        # -------------------------------------------------------------
        "gsmn": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jlang": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=False),
        "gsmn_wo_jgnd": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=False,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jclass": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=False, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jgoal": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=False, aux_lang=True),
        "gsmn_w_posnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=True, rot_noise=False),
        "gsmn_w_rotnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=False, rot_noise=True),
        "gsmn_w_bothnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=True, rot_noise=True),
        # -------------------------------------------------------------
        # RSS Baselines
        # -------------------------------------------------------------
        "gs_fpv": lambda: ModelGSFPV(
            run_name, aux_class_features=True,
            aux_grounding_features=True, aux_lang=True, recurrence=False),
        "gs_fpv_mem": lambda: ModelGSFPV(
            run_name, aux_class_features=True,
            aux_grounding_features=True, aux_lang=True, recurrence=True),
        # -------------------------------------------------------------
        # RSS Model for Cage
        # -------------------------------------------------------------
        "gsmn_cage": lambda: ModelRSS(
            run_name, model_class=msrg.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=False),
        "gsmn_bidomain": lambda: ModelGSMNBiDomain(
            run_name, model_instance_name=domain),
        "gsmn_critic": lambda: ModelGsmnCritic(run_name),
        # -------------------------------------------------------------
        # CoRL 2018 Model
        # -------------------------------------------------------------
        "sm_traj_nav_ratio": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.MODEL_FPV),
        "sm_traj_nav_ratio_path": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.PVN_STAGE1_ONLY),
        "action_gtr": lambda: ModelTrajectoryToAction(run_name),
        # -------------------------------------------------------------
        # CoRL 2018 Refactored
        # -------------------------------------------------------------
        "pvn_full": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.MODEL_FPV),
        "pvn_stage1": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.PVN_STAGE1_ONLY),
        "pvn_stage2": lambda: ModelTrajectoryToAction(run_name),
        # -------------------------------------------------------------
        # CoRL 2018 Top-Down Full Observability Models
        # -------------------------------------------------------------
        "top_down_goal_batched": lambda: ModelTopDownPathGoalPredictorBatched(
            run_name),
        # -------------------------------------------------------------
        # CoRL Model for cage (bidomain)
        # -------------------------------------------------------------
        "pvn_original_stage1_bidomain": lambda: PVN_Stage1_Bidomain_Original(
            run_name, domain=domain),
        "pvn_stage1_bidomain": lambda: PVN_Stage1_Bidomain(
            run_name, domain=domain),
        "pvn_stage2_bidomain": lambda: PVN_Stage2_Bidomain(
            run_name, model_instance_name=domain),
        "pvn_stage2_actor_critic": lambda: PVN_Stage2_ActorCritic(
            run_name, model_instance_name=domain),
        "pvn_stage1_critic": lambda: PVN_Stage1_Critic(run_name),
        "pvn_stage1_critic_big": lambda: PVN_Stage1_Critic_Big(run_name),
        "pvn_full_bidomain": lambda: PVN_Wrapper_Bidomain(
            run_name, model_instance_name=domain, oracle_stage1=False),
        "pvn_full_bidomain_ground_truth": lambda: PVN_Wrapper_Bidomain(
            run_name, model_instance_name=domain, oracle_stage1=True),
    }

    model = None
    pytorch_model = False

    # -----------------------------------------------------------------
    # Oracles / baselines that ignore images
    # -----------------------------------------------------------------
    if model_name == "oracle":
        rollout_params = get_current_parameters()["Rollout"]
        oracle_type = rollout_params["oracle_type"]
        if oracle_type == "SimpleCarrotPlanner":
            model = SimpleCarrotPlanner()
            print("Using simple carrot planner")
        elif oracle_type == "BasicCarrotPlanner":
            model = BasicCarrotPlanner()
            print("Using basic carrot planner")
        elif oracle_type == "FancyCarrotPlanner":
            model = FancyCarrotPlanner()
            print("Using fancy carrot planner")
        else:
            # Report the value under the key that was actually read; the
            # previous "OracleType" lookup raised KeyError on this path.
            print("UNKNOWN ORACLE: ", oracle_type)
            exit(-1)
    elif model_name == "average":
        model = BaselineAverage()
    elif model_name == "stop":
        model = BaselineStop()
    elif model_name in pytorch_factories:
        model = pytorch_factories[model_name]()
        pytorch_model = True

    model_loaded = False
    if pytorch_model:
        n_params = get_n_params(model)
        n_params_tr = get_n_trainable_params(model)
        print("Loaded PyTorch model!")
        print("Number of model parameters: " + str(n_params))
        print("Trainable model parameters: " + str(n_params_tr))
        model.init_weights()
        model.eval()
        if model_file:
            load_pytorch_model(model, model_file, pytorch3to4=True)
            print("Loaded previous model: ", model_file)
            model_loaded = True
        if cuda:
            model = model.cuda()

    return model, model_loaded