def train_supervised():
    """Run supervised training and checkpoint the model after every epoch.

    Reads the "Setup" and "Supervised" parameter groups, loads the model,
    optionally restricts the training environments to a named split,
    optionally resumes from the checkpoint written for ``start_epoch`` and
    optionally restores pretrained weights. Each epoch runs one training
    pass over the train environments followed by one evaluation pass over
    the dev environments.

    Side effects:
        * Saves the model under ``filename`` whenever the dev loss improves.
        * Saves a per-epoch checkpoint under ``tmp/``.
        * Terminates the process with status -1 when ``start_epoch > 0`` but
          the corresponding checkpoint file does not exist.
    """
    import sys  # local import: keeps this block independent of file-level imports

    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    supervised_params = get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]
    start_epoch = supervised_params["start_epoch"]

    model, model_loaded = load_model()

    print("Loading data")
    train_envs, dev_envs, test_envs = get_all_env_id_lists(max_envs=setup["max_envs"])

    if "split_train_data" in supervised_params and supervised_params["split_train_data"]:
        split_name = supervised_params["train_data_split"]
        split = load_env_split()[split_name]
        train_envs = [env_id for env_id in train_envs if env_id in split]
        print("Using " + str(len(train_envs)) + " envs from dataset split: " + split_name)

    filename = "supervised_" + setup["model"] + "_" + setup["run_name"]
    start_filename = "tmp/" + filename + "_epoch_" + str(start_epoch)
    if start_epoch > 0:
        # NOTE(review): confirm that file_exists and load_pytorch_model resolve
        # this path the same way (e.g. with respect to a file extension).
        if file_exists(start_filename):
            load_pytorch_model(model, start_filename)
        else:
            print("Couldn't continue training. Model file doesn't exist at:")
            print(start_filename)
            # sys.exit instead of the site-provided `exit` helper, which is
            # meant for interactive use and is absent under `python -S`.
            sys.exit(-1)

    if setup["restore_weights_name"]:
        restore_pretrained_weights(model, setup["restore_weights_name"], setup["fix_restored_weights"])

    trainer = Trainer(model, epoch=start_epoch, name=setup["model"], run_name=setup["run_name"])

    print("Beginning training...")
    # Start from +inf so the first epoch always produces a "best" checkpoint,
    # even when the loss scale is larger than any fixed threshold.
    best_test_loss = float("inf")
    for epoch in range(num_epochs):
        train_loss = trainer.train_epoch(train_data=None, train_envs=train_envs, eval=False)

        # Reset the goal counters so the printout below covers the dev pass only.
        trainer.model.correct_goals = 0
        trainer.model.total_goals = 0

        test_loss = trainer.train_epoch(train_data=None, train_envs=dev_envs, eval=True)

        print("GOALS: ", trainer.model.correct_goals, trainer.model.total_goals)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            save_pytorch_model(trainer.model, filename)
            print("Saved model in:", filename)
        print("Epoch", epoch, "train_loss:", train_loss, "test_loss:", test_loss)
        save_pytorch_model(trainer.model, "tmp/" + filename + "_epoch_" + str(epoch))
        if hasattr(trainer.model, "save"):
            trainer.model.save(epoch)
        save_pretrained_weights(trainer.model, setup["run_name"])
def train_rl():
    """Run RL training, resuming from the latest per-epoch checkpoint if any.

    Reads the "Setup" and "RL" parameter groups and builds a ``TrainerRL``
    from their merge. The resume point is the first epoch index (below
    10000) whose checkpoint file is missing from the model directory; when
    that index is greater than zero, the checkpoint of the preceding epoch
    is loaded into ``trainer.full_model``.

    Each epoch runs one training pass on the "train" environments and one
    evaluation pass on the "dev" environments, prints both rewards and saves
    a per-epoch checkpoint.
    """
    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    params = get_current_parameters()["RL"]
    print("Loading data")
    # The returned lists are not used here; the call is kept because it was
    # part of the original start-up sequence.
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    filename = "rl_" + setup["model"] + "_" + setup["run_name"]

    trainer = TrainerRL(params=dict_merge(setup, params))

    # Find the first epoch that has no checkpoint on disk.
    for start_epoch in range(10000):
        epfname = epoch_filename(filename, start_epoch)
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_epoch > 0:
        print(f"CONTINUING RL TRAINING FROM EPOCH: {start_epoch}")
        load_pytorch_model(trainer.full_model, epoch_filename(filename, start_epoch - 1))
        trainer.set_start_epoch(start_epoch)

    print("Beginning training...")
    for epoch in range(start_epoch, 10000):
        train_reward, metrics = trainer.train_epoch(eval=False, envs="train")
        # TODO: Test on just a few dev environments
        # TODO: Take most likely or mean action when testing
        # The dev reward used to be overwritten with 0 right after being
        # computed, so the log line below always reported 0. Report the
        # measured value instead.
        dev_reward, metrics = trainer.train_epoch(eval=True, envs="dev")

        print("Epoch", epoch, "train reward:", train_reward, "dev reward:", dev_reward)
        save_pytorch_model(trainer.full_model, epoch_filename(filename, epoch))
        if hasattr(trainer.full_model, "save"):
            trainer.full_model.save(epoch)
def load_img_feature_weights(self):
    """Load pretrained weights into the image feature network, if enabled.

    Does nothing unless ``self.params["load_feature_net"]`` is truthy.
    Otherwise loads the state stored under ``feature_net_filename``,
    optionally narrows it with ``find_state_subdict`` using
    ``feature_net_tensor_name`` as the prefix, and feeds the result to
    ``self.img_to_features_w.img_to_features.load_state_dict``.
    """
    if not self.params.get("load_feature_net"):
        return

    filename = self.params.get("feature_net_filename")
    state = load_pytorch_model(None, filename)

    prefix = self.params.get("feature_net_tensor_name")
    if prefix:
        state = find_state_subdict(state, prefix)

    # TODO: This breaks OOP conventions
    feature_net = self.img_to_features_w.img_to_features
    feature_net.load_state_dict(state)
    print(f"Loaded pretrained weights from file {filename} with prefix {prefix}")
def train_supervised_worker(rl_process_conn):
    """Supervised-learning worker that runs alongside an RL process.

    Loads the sim, real and critic models onto the supervised device, builds
    a ``TrainerBidomainBidata`` and resumes from the latest supervised
    checkpoint if one exists. At the start of every epoch it sends the sim
    model's state dict through ``rl_process_conn``. Unless a seed dataset is
    configured, it then waits until the rollout dataset directory holds at
    least 20 entries before training. It then runs one training and one
    evaluation pass and saves all three models.

    Args:
        rl_process_conn: connection object whose ``send`` method is used to
            hand ``["stage1_model_state_dict", state_dict]`` to the RL process.
    """
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    rlsup = P.get_current_parameters()["RLSUP"]
    setup["trajectory_length"] = setup["sup_trajectory_length"]
    run_name = setup["run_name"]
    total_epochs = P.get_current_parameters()["Supervised"]["num_epochs"]
    sup_device = rlsup.get("sup_device", "cuda:1")
    model_oracle_critic = None

    print("SUPP: Loading data")
    train_envs, dev_envs, _ = get_restricted_env_id_lists()

    # Load the starter model and save it at epoch 0
    # Supervised worker to use GPU 1, RL will use GPU 0. Simulators run on GPU 2
    sup_model_name = setup["sup_model"]
    model_sim, _ = load_model(sup_model_name, setup["sim_model_file"], domain="sim")
    model_sim = model_sim.to(sup_device)
    model_real, _ = load_model(sup_model_name, setup["real_model_file"], domain="real")
    model_real = model_real.to(sup_device)
    model_critic, _ = load_model(setup["sup_critic_model"], setup["critic_model_file"])
    model_critic = model_critic.to(sup_device)

    # ------------------------------------------------------------------------
    print("SUPP: Initializing trainer")
    sim_seed_dataset = P.get_current_parameters()["RLSUP"].get("sim_seed_dataset")

    # TODO: Figure if 6000 or 7000 here
    trainer = TrainerBidomainBidata(model_real, model_sim, model_critic, model_oracle_critic, epoch=0)

    def common_only(env_ids):
        # Environments in [6000, 7000) form the "common" subset.
        return [e for e in env_ids if 6000 <= e < 7000]

    def sim_only(env_ids):
        # Environments below 7000 form the "sim" subset.
        return [e for e in env_ids if e < 7000]

    train_envs_common = common_only(train_envs)
    train_envs_sim = sim_only(train_envs)
    dev_envs_common = common_only(dev_envs)
    dev_envs_sim = sim_only(dev_envs)

    trainer.set_dataset_names(sim_datasets=[rl_dataset_name(run_name)], real_datasets=["real"])

    # Models paired with the keyword arguments that name their checkpoints.
    # The order (real, sim, critic) is the order used for loading and saving.
    checkpointed = (
        (model_real, {"model": "stage1", "domain": "real"}),
        (model_sim, {"model": "stage1", "domain": "sim"}),
        (model_critic, {"model": "critic", "domain": "critic"}),
    )

    # ------------------------------------------------------------------------
    # Resume point: first epoch whose sim stage-1 checkpoint is not on disk.
    for start_sup_epoch in range(10000):
        probe = epoch_sup_filename(run_name, start_sup_epoch, model="stage1", domain="sim")
        if not os.path.exists(os.path.join(get_model_dir(), str(probe) + ".pytorch")):
            break
    if start_sup_epoch > 0:
        print(f"SUPP: CONTINUING SUP TRAINING FROM EPOCH: {start_sup_epoch}")
        for net, tag in checkpointed:
            load_pytorch_model(net, epoch_sup_filename(run_name, start_sup_epoch - 1, **tag))
        trainer.set_start_epoch(start_sup_epoch)

    # ------------------------------------------------------------------------
    print("SUPP: Beginning training...")
    for epoch in range(start_sup_epoch, total_epochs):
        # Tell the RL process that a new Stage 1 model is ready for loading
        print("SUPP: Sending model to RL")
        model_sim.reset()
        rl_process_conn.send(["stage1_model_state_dict", model_sim.state_dict()])

        if DEBUG_RL:
            # Debug mode: park this worker forever after the first send.
            while True:
                sleep(1)

        if not sim_seed_dataset:
            rollout_dir = get_dataset_dir(rl_dataset_name(run_name))
            os.makedirs(rollout_dir, exist_ok=True)
            while len(os.listdir(rollout_dir)) < 20:
                print("SUPP: Waiting for rollouts to appear")
                sleep(3)

        print("SUPP: Beginning Epoch")
        train_loss = trainer.train_epoch(
            env_list_common=train_envs_common, env_list_sim=train_envs_sim, eval=False)
        test_loss = trainer.train_epoch(
            env_list_common=dev_envs_common, env_list_sim=dev_envs_sim, eval=True)
        print("SUPP: Epoch", epoch, "train_loss:", train_loss, "test_loss:", test_loss)

        for net, tag in checkpointed:
            save_pytorch_model(net, epoch_sup_filename(run_name, epoch, **tag))
def train_rl_worker(sup_process_conn):
    """RL worker that runs alongside the supervised process.

    Builds a ``TrainerRL`` that saves its rollouts to the run's RL dataset
    and resumes from the latest "full" RL checkpoint if one exists. In every
    epoch it asks ``receive_stage1_state`` for a new Stage 1 state dict
    (with ``halt=True`` on the first epoch of this run) and reloads Stage 1
    when one is returned. It then runs one training and one evaluation pass
    and saves the full model plus its stage-1 and stage-2 sub-models.

    Args:
        sup_process_conn: connection to the supervised process, passed
            through to ``receive_stage1_state``.
    """
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    setup["trajectory_length"] = setup["rl_trajectory_length"]
    run_name = setup["run_name"]
    rlsup = P.get_current_parameters()["RLSUP"]
    rl_params = P.get_current_parameters()["RL"]
    num_rl_epochs = rl_params["num_epochs"]

    # These need to be distinguished between supervised and RL because
    # supervised trains on ALL envs, RL only on 6000-7000
    setup["env_range_start"] = setup["rl_env_range_start"]
    setup["env_range_end"] = setup["rl_env_range_end"]
    device = rlsup.get("rl_device", "cuda:0")

    trainer = TrainerRL(
        params=dict_merge(setup, rl_params),
        save_rollouts_to_dataset=rl_dataset_name(run_name),
        device=device,
    )

    # ------------------------------------------------------------------------
    # TODO: Continue (including figure out how to initialize Supervised Stage 1
    # real/sim/critic and RL Stage 2 policy
    # Pre-set to 0 so the name is bound even when num_rl_epochs is 0.
    start_rl_epoch = 0
    for start_rl_epoch in range(num_rl_epochs):
        candidate = epoch_rl_filename(run_name, start_rl_epoch, model="full")
        checkpoint_path = os.path.join(get_model_dir(), str(candidate) + ".pytorch")
        if not os.path.exists(checkpoint_path):
            break
    if start_rl_epoch > 0:
        print(f"RLP: CONTINUING RL TRAINING FROM EPOCH: {start_rl_epoch}")
        previous = epoch_rl_filename(run_name, start_rl_epoch - 1, model="full")
        load_pytorch_model(trainer.full_model, previous)
        trainer.set_start_epoch(start_rl_epoch)

    # Wait for supervised process to send its model
    sleep(2)

    # ------------------------------------------------------------------------
    print("RLP: Beginning training...")
    for rl_epoch in range(start_rl_epoch, num_rl_epochs):
        # Get the latest Stage 1 model. Halt on the first epoch so that we can
        # actually initialize the Stage 1
        is_first_epoch = rl_epoch == start_rl_epoch
        stage1_state = receive_stage1_state(sup_process_conn, halt=is_first_epoch)
        if stage1_state:
            print("RLP: Re-loading latest Stage 1 model")
            trainer.reload_stage1(stage1_state)

        train_reward, _ = trainer.train_epoch(epoch_num=rl_epoch, eval=False, envs="train")
        dev_reward, _ = trainer.train_epoch(epoch_num=rl_epoch, eval=True, envs="dev")

        print("RLP: RL Epoch", rl_epoch, "train reward:", train_reward, "dev reward:", dev_reward)
        save_pytorch_model(
            trainer.full_model,
            epoch_rl_filename(run_name, rl_epoch, model="full"))
        save_pytorch_model(
            trainer.full_model.stage1_visitation_prediction,
            epoch_rl_filename(run_name, rl_epoch, model="stage1"))
        save_pytorch_model(
            trainer.full_model.stage2_action_generation,
            epoch_rl_filename(run_name, rl_epoch, model="stage2"))
def train_supervised():
    """Run supervised training with tensorboard logging and per-epoch checkpoints.

    Reads the "Setup" and "Supervised" parameter groups, loads the model,
    optionally restricts the training environments to a named split,
    optionally resumes from the checkpoint written for ``start_epoch`` and
    optionally restores pretrained weights. A ``Logger`` built from the
    "Environment" group's ``tensorboard_dir`` is attached to the model, to
    the model's ``goal_good_criterion`` (when present) and to the trainer.

    Each epoch runs one training pass over the train environments followed
    by one evaluation pass over the dev environments.

    Side effects:
        * Saves the model under ``filename`` whenever the dev loss improves.
        * Saves a per-epoch checkpoint under ``tmp/``.
        * Terminates the process with status -1 when ``start_epoch > 0`` but
          the corresponding checkpoint file does not exist.
    """
    import sys  # local import: keeps this block independent of file-level imports

    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    supervised_params = get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]
    start_epoch = supervised_params["start_epoch"]

    model, model_loaded = load_model()
    print("model:", model)
    print("model type:", type(model))

    print("Loading data")
    train_envs, dev_envs, test_envs = get_all_env_id_lists(max_envs=setup["max_envs"])

    if "split_train_data" in supervised_params and supervised_params["split_train_data"]:
        split_name = supervised_params["train_data_split"]
        split = load_env_split()[split_name]
        train_envs = [env_id for env_id in train_envs if env_id in split]
        print("Using " + str(len(train_envs)) + " envs from dataset split: " + split_name)

    filename = "supervised_" + setup["model"] + "_" + setup["run_name"]

    # Code looks weird here because load_pytorch_model adds ".pytorch" to end
    # of path, but file_exists doesn't
    model_path = "tmp/" + filename + "_epoch_" + str(start_epoch)
    model_path_with_extension = model_path + ".pytorch"
    print("model path:", model_path_with_extension)
    if start_epoch > 0:
        if file_exists(model_path_with_extension):
            print("THE FILE EXISTS code1")
            load_pytorch_model(model, model_path)
        else:
            print("Couldn't continue training. Model file doesn't exist at:")
            print(model_path_with_extension)
            # sys.exit instead of the site-provided `exit` helper, which is
            # meant for interactive use and is absent under `python -S`.
            sys.exit(-1)

    if setup["restore_weights_name"]:
        restore_pretrained_weights(model, setup["restore_weights_name"], setup["fix_restored_weights"])

    # Add a tensorboard logger to the model and trainer
    tensorboard_dir = get_current_parameters()['Environment']['tensorboard_dir']
    logger = Logger(tensorboard_dir)
    model.logger = logger
    if hasattr(model, "goal_good_criterion"):
        print("gave logger to goal evaluator")
        model.goal_good_criterion.logger = logger

    trainer = Trainer(model, epoch=start_epoch, name=setup["model"], run_name=setup["run_name"])
    trainer.logger = logger

    print("Beginning training...")
    # Start from +inf so the first epoch always produces a "best" checkpoint,
    # even when the loss scale is larger than any fixed threshold.
    best_test_loss = float("inf")

    print("filename:", filename)
    # A leftover pdb.set_trace() used to sit here and blocked every unattended
    # run at this point; it has been removed.
    # NOTE(review): the epoch counter restarts at 0 even when resuming from
    # start_epoch > 0 - confirm whether per-epoch checkpoints should instead
    # continue from the resumed epoch.
    for epoch in range(num_epochs):
        train_loss = trainer.train_epoch(train_data=None, train_envs=train_envs, eval=False)

        # Reset the goal counters so the printout below covers the dev pass only.
        trainer.model.correct_goals = 0
        trainer.model.total_goals = 0

        test_loss = trainer.train_epoch(train_data=None, train_envs=dev_envs, eval=True)

        print("GOALS: ", trainer.model.correct_goals, trainer.model.total_goals)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            save_pytorch_model(trainer.model, filename)
            print("Saved model in:", filename)
        print("Epoch", epoch, "train_loss:", train_loss, "test_loss:", test_loss)
        save_pytorch_model(trainer.model, "tmp/" + filename + "_epoch_" + str(epoch))
        if hasattr(trainer.model, "save"):
            trainer.model.save(epoch)
        save_pretrained_weights(trainer.model, setup["run_name"])
def train_dagger_simple():
    """Run the simple DAgger loop: roll out, prune, and train one epoch per iteration.

    Reads the "SimpleDagger" and "Setup" parameter groups. The DAgger
    dataset is either seeded by copying ``sim_seed_dataset`` or, when no
    seed dataset is configured, created/continued on disk (asking for
    confirmation on stdin if it already exists). Training resumes from the
    first iteration whose checkpoint file is missing.

    Each iteration prunes the dataset to ``dataset_size_limit_envs``,
    samples training and evaluation rollouts with the current policy state,
    shuts the simulator instances down, optionally evaluates the eval
    rollouts, trains for one epoch and saves a checkpoint.

    Raises:
        ValueError: if the dataset directory already exists and the user
            does not answer "y" to the continue prompt.
    """
    # ------------------------------------------------------------------------
    # Load params and configure stuff
    P.initialize_experiment()
    params = P.get_current_parameters()["SimpleDagger"]
    setup = P.get_current_parameters()["Setup"]
    num_iterations = params["num_iterations"]
    sim_seed_dataset = params.get("sim_seed_dataset")
    run_name = setup["run_name"]
    device = params.get("device", "cuda:1")
    dataset_limit = params.get("dataset_size_limit_envs")
    seed_count = params.get("seed_count")

    # Trigger rebuild if necessary before going into all the threads and processes
    _ = get_restricted_env_id_lists(full=True)

    # Initialize the dataset
    if sim_seed_dataset:
        copy_seed_dataset(from_dataset=sim_seed_dataset,
                          to_dataset=dagger_dataset_name(run_name),
                          seed_count=seed_count or dataset_limit)
        gap = 0
    else:
        # TODO: Refactor this into a prompt function
        data_path = get_dataset_dir(dagger_dataset_name(run_name))
        if os.path.exists(data_path):
            print("DATASET EXISTS! Continue where left off?")
            c = input(" (y/n) >>> ")
            if c != "y":
                raise ValueError(
                    f"Not continuing: Dataset {data_path} exists. Delete it if you like and try again"
                )
        else:
            os.makedirs(data_path, exist_ok=True)
        # Number of rollouts still missing to fill the dataset up to its limit.
        gap = dataset_limit - len(os.listdir(data_path))

    print("SUPP: Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    # ------------------------------------------------------------------------
    # Load / initialize model
    model = load_model(setup["model"], setup["model_file"], domain="sim")[0].to(device)
    oracle = load_model("oracle")[0]

    # ------------------------------------------------------------------------
    # Continue where we left off - load the model and set the iteration/epoch number
    for start_iteration in range(10000):
        epfname = epoch_dag_filename(run_name, start_iteration)
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_iteration > 0:
        print(f"DAG: CONTINUING DAGGER TRAINING FROM ITERATION: {start_iteration}")
        load_pytorch_model(model, epoch_dag_filename(run_name, start_iteration - 1))

    # ------------------------------------------------------------------------
    # Initialize trainer
    trainer = Trainer(model, epoch=start_iteration, name=setup["model"], run_name=setup["run_name"])
    trainer.set_dataset_names([dagger_dataset_name(run_name)])

    # ------------------------------------------------------------------------
    # Initialize policy roller
    roller = SimpleParallelPolicyRoller(
        num_workers=params["num_workers"],
        # Use the resolved device (with its "cuda:1" fallback) rather than
        # params["device"], which raised KeyError when the key was omitted
        # even though a default had already been applied above.
        device=device,
        policy_name=setup["model"],
        policy_file=setup["model_file"],
        oracle=oracle,
        dataset_save_name=dagger_dataset_name(run_name),
        no_reward=True)
    rollout_sampler = RolloutSampler(roller)

    # ------------------------------------------------------------------------
    # Train DAgger - loop over iterations, in each, prune, rollout and train an epoch
    print("SUPP: Beginning training...")
    for iteration in range(start_iteration, num_iterations):
        print(f"DAG: Starting iteration {iteration}")

        # Remove extra rollouts to keep within DAggerFM limit
        prune_dataset(run_name, dataset_limit)

        # Rollout and collect more data for training and evaluation
        policy_state = model.get_policy_state()
        rollout_sampler.sample_n_rollouts(
            # The very first iteration only fills the dataset up to its limit.
            n=gap if iteration == 0 else params["train_envs_per_iteration"],
            policy_state=policy_state,
            sample=False,
            envs="train",
            dagger_beta=dagger_beta(params, iteration))

        eval_rollouts = rollout_sampler.sample_n_rollouts(
            n=params["eval_envs_per_iteration"],
            policy_state=policy_state,
            sample=False,
            envs="dev",
            dagger_beta=0)

        # Kill airsim instances so that they don't take up GPU memory and in
        # general slow things down during training
        roller.kill_airsim()

        # Evaluate success / metrics and save to tensorboard
        if setup["eval_nl"]:
            evaler = DataEvalNL(run_name, entire_trajectory=False, save_images=False)
            evaler.evaluate_dataset(eval_rollouts)
            results = evaler.get_results()
            print("Results:", results)
            evaler.write_summaries(setup["run_name"], "dagger_eval", iteration)

        # Do one epoch of supervised training
        print("SUPP: Beginning Epoch")
        train_loss = trainer.train_epoch(train_envs=train_envs, eval=False)

        # Save the model to file
        print("SUPP: Epoch", iteration, "train_loss:", train_loss)
        save_pytorch_model(model, epoch_dag_filename(run_name, iteration))
def load_model(model_file_override=None, real=False, model_name_override=False):
    """Instantiate the model named in the "Setup" parameters (or an override).

    Args:
        model_file_override: weights file to load instead of
            ``setup["model_file"]``; a falsy value falls back to the setup.
        real: accepted for interface compatibility; not used by this function.
        model_name_override: model name to use instead of ``setup["model"]``;
            a falsy value falls back to the setup.

    Returns:
        A ``(model, model_loaded)`` tuple. ``model_loaded`` is True only when
        a PyTorch model was built and a weights file was loaded into it.
        ``model`` is None when the name matches no known model.

    For the "oracle" model an unrecognised ``oracle_type`` prints an error
    and terminates the process with status -1.
    """
    import sys  # local import: keeps this block independent of file-level imports

    setup = P.get_current_parameters()["Setup"]
    model_name = model_name_override or setup["model"]
    model_file = model_file_override or setup["model_file"] or None
    cuda = setup["cuda"]
    run_name = setup["run_name"]

    # Every table entry is a zero-argument factory, so only the requested
    # model class is ever looked up and instantiated.

    # Baselines that ignore images (not PyTorch models).
    baseline_factories = {
        "baseline_straight": lambda: BaselineStraight(),
        "baseline_stop": lambda: BaselineStop(),
    }

    pytorch_factories = {
        # ---- FASTER RSS 2018 Resubmission Model ----
        "gsmn": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jlang": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=False),
        "gsmn_wo_jgnd": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=False,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jclass": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=False, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jgoal": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=False, aux_lang=True),
        "gsmn_w_posnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=True, rot_noise=False),
        "gsmn_w_rotnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=False, rot_noise=True),
        "gsmn_w_bothnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=True, rot_noise=True),
        "gs_fpv": lambda: ModelGSFPV(
            run_name, aux_class_features=True, aux_grounding_features=True,
            aux_lang=True, recurrence=False),
        "gs_fpv_mem": lambda: ModelGSFPV(
            run_name, aux_class_features=True, aux_grounding_features=True,
            aux_lang=True, recurrence=True),
        # ---- CoRL 2018 Model ----
        "sm_traj_nav_ratio": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.MODEL_FPV),
        "sm_traj_nav_ratio_path": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.PVN_STAGE1_ONLY),
        "action_gtr": lambda: ModelTrajectoryToAction(run_name),
        # ---- CoRL 2018 Refactored ----
        "pvn_full": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.MODEL_FPV),
        "pvn_stage1": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.PVN_STAGE1_ONLY),
        "pvn_stage2": lambda: ModelTrajectoryToAction(run_name),
        # ---- CoRL 2018 Top-Down Full Observability Models ----
        "top_down_goal_batched": lambda: ModelTopDownPathGoalPredictorBatched(run_name),
        # ---- CoRL Baselines ----
        "chaplot": lambda: ModelChaplot(run_name),
        "misra2017": lambda: ModelMisra2017(run_name),
    }

    model = None
    pytorch_model = False

    if model_name == "oracle":
        # Oracles ignore images and are not PyTorch models.
        oracle_type = get_current_parameters()["Rollout"]["oracle_type"]
        if oracle_type == "SimpleCarrotPlanner":
            model = SimpleCarrotPlanner()
            print("Using simple carrot planner")
        elif oracle_type == "BasicCarrotPlanner":
            model = BasicCarrotPlanner()
            print("Using basic carrot planner")
        elif oracle_type == "FancyCarrotPlanner":
            model = FancyCarrotPlanner()
            print("Using fancy carrot planner")
        else:
            # The message used to index the non-existent "OracleType" key,
            # which raised KeyError before anything was printed.
            print("UNKNOWN ORACLE: ", oracle_type)
            sys.exit(-1)
    elif model_name in baseline_factories:
        model = baseline_factories[model_name]()
    elif model_name in pytorch_factories:
        model = pytorch_factories[model_name]()
        pytorch_model = True

    model_loaded = False
    if pytorch_model:
        n_params = get_n_params(model)
        n_params_tr = get_n_trainable_params(model)
        print("Loaded PyTorch model!")
        print("Number of model parameters: " + str(n_params))
        print("Trainable model parameters: " + str(n_params_tr))
        model.init_weights()
        model.eval()
        if model_file:
            load_pytorch_model(model, model_file)
            print("Loaded previous model: ", model_file)
            model_loaded = True
        if cuda:
            model = model.cuda()

    return model, model_loaded
def load_model(model_name_override=False, model_file_override=None, domain="sim"):
    """Instantiate the model named in the "Setup" parameters (or an override).

    Args:
        model_name_override: model name to use instead of ``setup["model"]``;
            a falsy value falls back to the setup.
        model_file_override: weights file to load instead of
            ``setup["model_file"]``; a falsy value falls back to the setup.
        domain: forwarded to the bidomain model constructors, either as
            ``domain=`` or as ``model_instance_name=`` depending on the class.

    Returns:
        A ``(model, model_loaded)`` tuple. ``model_loaded`` is True only when
        a PyTorch model was built and a weights file was loaded into it.
        ``model`` is None when the name matches no known model.

    For the "oracle" model an unrecognised ``oracle_type`` prints an error
    and terminates the process with status -1.
    """
    import sys  # local import: keeps this block independent of file-level imports

    setup = P.get_current_parameters()["Setup"]
    model_name = model_name_override or setup["model"]
    model_file = model_file_override or setup["model_file"] or None
    cuda = setup["cuda"]
    run_name = setup["run_name"]

    # Every table entry is a zero-argument factory, so only the requested
    # model class is ever looked up and instantiated.

    # Baselines that ignore images (not PyTorch models).
    baseline_factories = {
        "average": lambda: BaselineAverage(),
        "stop": lambda: BaselineStop(),
    }

    pytorch_factories = {
        # ---- FASTER RSS 2018 Resubmission Model ----
        "gsmn": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jlang": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=False),
        "gsmn_wo_jgnd": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=False,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jclass": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=False, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True),
        "gsmn_wo_jgoal": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=False, aux_lang=True),
        "gsmn_w_posnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=True, rot_noise=False),
        "gsmn_w_rotnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=False, rot_noise=True),
        "gsmn_w_bothnoise": lambda: ModelRSS(
            run_name, model_class=ModelRSS.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=True,
            pos_noise=True, rot_noise=True),
        # ---- RSS Baselines ----
        "gs_fpv": lambda: ModelGSFPV(
            run_name, aux_class_features=True, aux_grounding_features=True,
            aux_lang=True, recurrence=False),
        "gs_fpv_mem": lambda: ModelGSFPV(
            run_name, aux_class_features=True, aux_grounding_features=True,
            aux_lang=True, recurrence=True),
        # ---- RSS Model for Cage ----
        "gsmn_cage": lambda: ModelRSS(
            run_name, model_class=msrg.MODEL_RSS,
            aux_class_features=False, aux_grounding_features=False,
            aux_class_map=True, aux_grounding_map=True,
            aux_goal_map=True, aux_lang=False),
        "gsmn_bidomain": lambda: ModelGSMNBiDomain(
            run_name, model_instance_name=domain),
        "gsmn_critic": lambda: ModelGsmnCritic(run_name),
        # ---- CoRL 2018 Model ----
        "sm_traj_nav_ratio": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.MODEL_FPV),
        "sm_traj_nav_ratio_path": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.PVN_STAGE1_ONLY),
        "action_gtr": lambda: ModelTrajectoryToAction(run_name),
        # ---- CoRL 2018 Refactored ----
        "pvn_full": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.MODEL_FPV),
        "pvn_stage1": lambda: ModelTrajectoryProbRatio(
            run_name, model_class=mtpr.PVN_STAGE1_ONLY),
        "pvn_stage2": lambda: ModelTrajectoryToAction(run_name),
        # ---- CoRL 2018 Top-Down Full Observability Models ----
        "top_down_goal_batched": lambda: ModelTopDownPathGoalPredictorBatched(run_name),
        # ---- CoRL Model for cage (bidomain) ----
        "pvn_original_stage1_bidomain": lambda: PVN_Stage1_Bidomain_Original(
            run_name, domain=domain),
        "pvn_stage1_bidomain": lambda: PVN_Stage1_Bidomain(
            run_name, domain=domain),
        "pvn_stage2_bidomain": lambda: PVN_Stage2_Bidomain(
            run_name, model_instance_name=domain),
        "pvn_stage2_actor_critic": lambda: PVN_Stage2_ActorCritic(
            run_name, model_instance_name=domain),
        "pvn_stage1_critic": lambda: PVN_Stage1_Critic(run_name),
        "pvn_stage1_critic_big": lambda: PVN_Stage1_Critic_Big(run_name),
        "pvn_full_bidomain": lambda: PVN_Wrapper_Bidomain(
            run_name, model_instance_name=domain, oracle_stage1=False),
        "pvn_full_bidomain_ground_truth": lambda: PVN_Wrapper_Bidomain(
            run_name, model_instance_name=domain, oracle_stage1=True),
    }

    model = None
    pytorch_model = False

    if model_name == "oracle":
        # Oracles ignore images and are not PyTorch models.
        oracle_type = get_current_parameters()["Rollout"]["oracle_type"]
        if oracle_type == "SimpleCarrotPlanner":
            model = SimpleCarrotPlanner()
            print("Using simple carrot planner")
        elif oracle_type == "BasicCarrotPlanner":
            model = BasicCarrotPlanner()
            print("Using basic carrot planner")
        elif oracle_type == "FancyCarrotPlanner":
            model = FancyCarrotPlanner()
            print("Using fancy carrot planner")
        else:
            # The message used to index the non-existent "OracleType" key,
            # which raised KeyError before anything was printed.
            print("UNKNOWN ORACLE: ", oracle_type)
            sys.exit(-1)
    elif model_name in baseline_factories:
        model = baseline_factories[model_name]()
    elif model_name in pytorch_factories:
        model = pytorch_factories[model_name]()
        pytorch_model = True

    model_loaded = False
    if pytorch_model:
        n_params = get_n_params(model)
        n_params_tr = get_n_trainable_params(model)
        print("Loaded PyTorch model!")
        print("Number of model parameters: " + str(n_params))
        print("Trainable model parameters: " + str(n_params_tr))
        model.init_weights()
        model.eval()
        if model_file:
            load_pytorch_model(model, model_file, pytorch3to4=True)
            print("Loaded previous model: ", model_file)
            model_loaded = True
        if cuda:
            model = model.cuda()

    return model, model_loaded