Example #1
def train_supervised():
    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    supervised_params = get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]

    model, model_loaded = load_model()

    print("Loading data")
    train_envs, dev_envs, test_envs = get_all_env_id_lists(max_envs=setup["max_envs"])

    if "split_train_data" in supervised_params and supervised_params["split_train_data"]:
        split_name = supervised_params["train_data_split"]
        split = load_env_split()[split_name]
        train_envs = [env_id for env_id in train_envs if env_id in split]
        print("Using " + str(len(train_envs)) + " envs from dataset split: " + split_name)

    filename = "supervised_" + setup["model"] + "_" + setup["run_name"]
    start_filename = "tmp/" + filename + "_epoch_" + str(supervised_params["start_epoch"])
    if supervised_params["start_epoch"] > 0:
        if file_exists(start_filename):
            load_pytorch_model(model, start_filename)
        else:
            print("Couldn't continue training. Model file doesn't exist at:")
            print(start_filename)
            exit(-1)

    if setup["restore_weights_name"]:
        restore_pretrained_weights(model, setup["restore_weights_name"], setup["fix_restored_weights"])

    trainer = Trainer(model, epoch=supervised_params["start_epoch"], name=setup["model"], run_name=setup["run_name"])

    print("Beginning training...")
    best_test_loss = 1000
    for epoch in range(num_epochs):
        train_loss = trainer.train_epoch(train_data=None, train_envs=train_envs, eval=False)

        trainer.model.correct_goals = 0
        trainer.model.total_goals = 0

        test_loss = trainer.train_epoch(train_data=None, train_envs=dev_envs, eval=True)

        print("GOALS: ", trainer.model.correct_goals, trainer.model.total_goals)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            save_pytorch_model(trainer.model, filename)
            print("Saved model in:", filename)
        print ("Epoch", epoch, "train_loss:", train_loss, "test_loss:", test_loss)
        save_pytorch_model(trainer.model, "tmp/" + filename + "_epoch_" + str(epoch))
        if hasattr(trainer.model, "save"):
            trainer.model.save(epoch)
        save_pretrained_weights(trainer.model, setup["run_name"])
Example #2
def train_rl():
    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    params = get_current_parameters()["RL"]

    print("Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    filename = "rl_" + setup["model"] + "_" + setup["run_name"]

    trainer = TrainerRL(params=dict_merge(setup, params))

    for start_epoch in range(10000):
        epfname = epoch_filename(filename, start_epoch)
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break

    if start_epoch > 0:
        print(f"CONTINUING RL TRAINING FROM EPOCH: {start_epoch}")
        load_pytorch_model(trainer.full_model,
                           epoch_filename(filename, start_epoch - 1))
        trainer.set_start_epoch(start_epoch)

    print("Beginning training...")
    best_dev_reward = -1e+10
    for epoch in range(start_epoch, 10000):
        train_reward, metrics = trainer.train_epoch(eval=False, envs="train")
        # TODO: Test on just a few dev environments
        # TODO: Take most likely or mean action when testing
        dev_reward, metrics = trainer.train_epoch(eval=True, envs="dev")
        # Dev reward is currently discarded; best-model checkpointing below is disabled
        dev_reward = 0

        #if dev_reward >= best_dev_reward:
        #    best_dev_reward = dev_reward
        #    save_pytorch_model(trainer.full_model, filename)
        #    print("Saved model in:", filename)

        print("Epoch", epoch, "train reward:", train_reward, "dev reward:",
              dev_reward)
        save_pytorch_model(trainer.full_model, epoch_filename(filename, epoch))
        if hasattr(trainer.full_model, "save"):
            trainer.full_model.save(epoch)
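
The resume logic above scans the model directory for the first missing per-epoch checkpoint and continues from there. Below is a minimal standalone sketch of that pattern; `model_dir` and `filename_for_epoch` are hypothetical stand-ins for `get_model_dir()` and `epoch_filename()`, which are not defined in these examples.

import os

def find_resume_epoch(model_dir, filename_for_epoch, max_epochs=10000):
    # Return the first epoch whose checkpoint file does not exist yet,
    # i.e. the epoch at which training should resume (0 if none exist).
    for epoch in range(max_epochs):
        path = os.path.join(model_dir, filename_for_epoch(epoch) + ".pytorch")
        if not os.path.exists(path):
            return epoch
    return max_epochs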
Example #3
def load_img_feature_weights(self):
    if self.params.get("load_feature_net"):
        filename = self.params.get("feature_net_filename")
        weights = load_pytorch_model(None, filename)
        prefix = self.params.get("feature_net_tensor_name")
        if prefix:
            weights = find_state_subdict(weights, prefix)
        # TODO: This breaks OOP conventions
        self.img_to_features_w.img_to_features.load_state_dict(weights)
        print(
            f"Loaded pretrained weights from file {filename} with prefix {prefix}"
        )
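
`find_state_subdict` is not shown in these examples. Assuming it selects the entries of a state dict that share a given prefix and strips that prefix so they can be loaded into a submodule, a plausible sketch would be:

def find_state_subdict(state_dict, prefix):
    # Keep only keys under `prefix` and strip the prefix (an assumption about
    # the helper's behavior, not its actual implementation).
    prefix = prefix if prefix.endswith(".") else prefix + "."
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}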
Example #4
def train_supervised_worker(rl_process_conn):
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    rlsup = P.get_current_parameters()["RLSUP"]
    setup["trajectory_length"] = setup["sup_trajectory_length"]
    run_name = setup["run_name"]
    supervised_params = P.get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]
    sup_device = rlsup.get("sup_device", "cuda:1")

    model_oracle_critic = None

    print("SUPP: Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    # Load the starter model and save it at epoch 0
    # Supervised worker to use GPU 1, RL will use GPU 0. Simulators run on GPU 2
    model_sim = load_model(setup["sup_model"],
                           setup["sim_model_file"],
                           domain="sim")[0].to(sup_device)
    model_real = load_model(setup["sup_model"],
                            setup["real_model_file"],
                            domain="real")[0].to(sup_device)
    model_critic = load_model(setup["sup_critic_model"],
                              setup["critic_model_file"])[0].to(sup_device)

    # ----------------------------------------------------------------------------------------------------------------

    print("SUPP: Initializing trainer")
    rlsup_params = P.get_current_parameters()["RLSUP"]
    sim_seed_dataset = rlsup_params.get("sim_seed_dataset")

    # TODO: Figure out whether 6000 or 7000 belongs here
    trainer = TrainerBidomainBidata(model_real,
                                    model_sim,
                                    model_critic,
                                    model_oracle_critic,
                                    epoch=0)
    train_envs_common = [e for e in train_envs if 6000 <= e < 7000]
    train_envs_sim = [e for e in train_envs if e < 7000]
    dev_envs_common = [e for e in dev_envs if 6000 <= e < 7000]
    dev_envs_sim = [e for e in dev_envs if e < 7000]
    sim_datasets = [rl_dataset_name(run_name)]
    real_datasets = ["real"]
    trainer.set_dataset_names(sim_datasets=sim_datasets,
                              real_datasets=real_datasets)

    # ----------------------------------------------------------------------------------------------------------------
    for start_sup_epoch in range(10000):
        epfname = epoch_sup_filename(run_name,
                                     start_sup_epoch,
                                     model="stage1",
                                     domain="sim")
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_sup_epoch > 0:
        print(f"SUPP: CONTINUING SUP TRAINING FROM EPOCH: {start_sup_epoch}")
        load_pytorch_model(
            model_real,
            epoch_sup_filename(run_name,
                               start_sup_epoch - 1,
                               model="stage1",
                               domain="real"))
        load_pytorch_model(
            model_sim,
            epoch_sup_filename(run_name,
                               start_sup_epoch - 1,
                               model="stage1",
                               domain="sim"))
        load_pytorch_model(
            model_critic,
            epoch_sup_filename(run_name,
                               start_sup_epoch - 1,
                               model="critic",
                               domain="critic"))
        trainer.set_start_epoch(start_sup_epoch)

    # ----------------------------------------------------------------------------------------------------------------
    print("SUPP: Beginning training...")
    for epoch in range(start_sup_epoch, num_epochs):
        # Tell the RL process that a new Stage 1 model is ready for loading
        print("SUPP: Sending model to RL")
        model_sim.reset()
        rl_process_conn.send(
            ["stage1_model_state_dict",
             model_sim.state_dict()])
        if DEBUG_RL:
            while True:
                sleep(1)

        if not sim_seed_dataset:
            ddir = get_dataset_dir(rl_dataset_name(run_name))
            os.makedirs(ddir, exist_ok=True)
            while len(os.listdir(ddir)) < 20:
                print("SUPP: Waiting for rollouts to appear")
                sleep(3)

        print("SUPP: Beginning Epoch")
        train_loss = trainer.train_epoch(env_list_common=train_envs_common,
                                         env_list_sim=train_envs_sim,
                                         eval=False)
        test_loss = trainer.train_epoch(env_list_common=dev_envs_common,
                                        env_list_sim=dev_envs_sim,
                                        eval=True)
        print("SUPP: Epoch", epoch, "train_loss:", train_loss, "test_loss:",
              test_loss)
        save_pytorch_model(
            model_real,
            epoch_sup_filename(run_name, epoch, model="stage1", domain="real"))
        save_pytorch_model(
            model_sim,
            epoch_sup_filename(run_name, epoch, model="stage1", domain="sim"))
        save_pytorch_model(
            model_critic,
            epoch_sup_filename(run_name,
                               epoch,
                               model="critic",
                               domain="critic"))
Example #5
def train_rl_worker(sup_process_conn):
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    setup["trajectory_length"] = setup["rl_trajectory_length"]
    run_name = setup["run_name"]
    rlsup = P.get_current_parameters()["RLSUP"]
    params = P.get_current_parameters()["RL"]
    num_rl_epochs = params["num_epochs"]
    # These need to be set separately for supervised and RL, because supervised trains on ALL envs while RL trains only on envs 6000-7000
    setup["env_range_start"] = setup["rl_env_range_start"]
    setup["env_range_end"] = setup["rl_env_range_end"]
    rl_device = rlsup.get("rl_device", "cuda:0")

    trainer = TrainerRL(params=dict_merge(setup, params),
                        save_rollouts_to_dataset=rl_dataset_name(run_name),
                        device=rl_device)

    # -------------------------------------------------------------------------------------
    # TODO: Continue (including figuring out how to initialize the Supervised Stage 1 real/sim/critic models and the RL Stage 2 policy)
    start_rl_epoch = 0
    for start_rl_epoch in range(num_rl_epochs):
        epfname = epoch_rl_filename(run_name, start_rl_epoch, model="full")
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_rl_epoch > 0:
        print(f"RLP: CONTINUING RL TRAINING FROM EPOCH: {start_rl_epoch}")
        load_pytorch_model(
            trainer.full_model,
            epoch_rl_filename(run_name, start_rl_epoch - 1, model="full"))
        trainer.set_start_epoch(start_rl_epoch)
    # Wait for the supervised process to send its model
    sleep(2)

    # -------------------------------------------------------------------------------------

    print("RLP: Beginning training...")
    for rl_epoch in range(start_rl_epoch, num_rl_epochs):
        # Get the latest Stage 1 model. Halt on the first epoch so that we can actually initialize the Stage 1 model
        new_stage1_model_state_dict = receive_stage1_state(
            sup_process_conn, halt=(rl_epoch == start_rl_epoch))
        if new_stage1_model_state_dict:
            print(f"RLP: Re-loading latest Stage 1 model")
            trainer.reload_stage1(new_stage1_model_state_dict)

        train_reward, metrics = trainer.train_epoch(epoch_num=rl_epoch,
                                                    eval=False,
                                                    envs="train")
        dev_reward, metrics = trainer.train_epoch(epoch_num=rl_epoch,
                                                  eval=True,
                                                  envs="dev")

        print("RLP: RL Epoch", rl_epoch, "train reward:", train_reward,
              "dev reward:", dev_reward)
        save_pytorch_model(trainer.full_model,
                           epoch_rl_filename(run_name, rl_epoch, model="full"))
        save_pytorch_model(
            trainer.full_model.stage1_visitation_prediction,
            epoch_rl_filename(run_name, rl_epoch, model="stage1"))
        save_pytorch_model(
            trainer.full_model.stage2_action_generation,
            epoch_rl_filename(run_name, rl_epoch, model="stage2"))
Example #6
def train_supervised():
    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    supervised_params = get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]

    model, model_loaded = load_model()
    # import pdb; pdb.set_trace()
    # import pickle
    # with open('/storage/dxsun/model_input.pickle', 'rb') as f: data = pickle.load(f)
    # g = model(data['images'], data['states'], data['instructions'], data['instr_lengths'], data['has_obs'], data['plan'], data['save_maps_only'], data['pos_enc'], data['noisy_poses'], data['start_poses'], data['firstseg'])
    print("model:", model)
    print("model type:", type(model))
    print("Loading data")
    # import pdb;pdb.set_trace()
    train_envs, dev_envs, test_envs = get_all_env_id_lists(
        max_envs=setup["max_envs"])
    if "split_train_data" in supervised_params and supervised_params[
            "split_train_data"]:
        split_name = supervised_params["train_data_split"]
        split = load_env_split()[split_name]
        train_envs = [env_id for env_id in train_envs if env_id in split]
        print("Using " + str(len(train_envs)) + " envs from dataset split: " +
              split_name)

    filename = "supervised_" + setup["model"] + "_" + setup["run_name"]

    # The paths look odd here because load_pytorch_model appends ".pytorch" to the path,
    # but file_exists does not
    model_path = "tmp/" + filename + "_epoch_" + str(
        supervised_params["start_epoch"])
    model_path_with_extension = model_path + ".pytorch"
    print("model path:", model_path_with_extension)
    if supervised_params["start_epoch"] > 0:
        if file_exists(model_path_with_extension):
            print("THE FILE EXISTS code1")
            load_pytorch_model(model, model_path)
        else:
            print("Couldn't continue training. Model file doesn't exist at:")
            print(model_path_with_extension)
            exit(-1)
    # import pdb;pdb.set_trace()
    ## If you just want to use the pretrained model
    # load_pytorch_model(model, "supervised_pvn_stage1_train_corl_pvn_stage1")

    # all_train_data, all_test_data = data_io.train_data.load_supervised_data(max_envs=100)
    if setup["restore_weights_name"]:
        restore_pretrained_weights(model, setup["restore_weights_name"],
                                   setup["fix_restored_weights"])

    # Add a tensorboard logger to the model and trainer
    tensorboard_dir = get_current_parameters(
    )['Environment']['tensorboard_dir']
    logger = Logger(tensorboard_dir)
    model.logger = logger
    if hasattr(model, "goal_good_criterion"):
        print("gave logger to goal evaluator")
        model.goal_good_criterion.logger = logger

    trainer = Trainer(model,
                      epoch=supervised_params["start_epoch"],
                      name=setup["model"],
                      run_name=setup["run_name"])

    trainer.logger = logger

    # import pdb;pdb.set_trace()
    print("Beginning training...")
    best_test_loss = 1000

    continue_epoch = supervised_params["start_epoch"] + 1 if supervised_params[
        "start_epoch"] > 0 else 0
    rng = range(0, num_epochs)
    print("filename:", filename)

    # import pdb; pdb.set_trace()  # leftover debug breakpoint; keep commented out for unattended training

    for epoch in rng:
        # import pdb;pdb.set_trace()
        train_loss = trainer.train_epoch(train_data=None,
                                         train_envs=train_envs,
                                         eval=False)
        # train_loss = trainer.train_epoch(train_data=all_train_data, train_envs=train_envs, eval=False)

        trainer.model.correct_goals = 0
        trainer.model.total_goals = 0

        test_loss = trainer.train_epoch(train_data=None,
                                        train_envs=dev_envs,
                                        eval=True)

        print("GOALS: ", trainer.model.correct_goals,
              trainer.model.total_goals)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            save_pytorch_model(trainer.model, filename)
            print("Saved model in:", filename)
        print("Epoch", epoch, "train_loss:", train_loss, "test_loss:",
              test_loss)
        save_pytorch_model(trainer.model,
                           "tmp/" + filename + "_epoch_" + str(epoch))
        if hasattr(trainer.model, "save"):
            trainer.model.save(epoch)
        save_pretrained_weights(trainer.model, setup["run_name"])
Example #7
def train_dagger_simple():
    # ----------------------------------------------------------------------------------------------------------------
    # Load params and configure stuff

    P.initialize_experiment()
    params = P.get_current_parameters()["SimpleDagger"]
    setup = P.get_current_parameters()["Setup"]
    num_iterations = params["num_iterations"]
    sim_seed_dataset = params.get("sim_seed_dataset")
    run_name = setup["run_name"]
    device = params.get("device", "cuda:1")
    dataset_limit = params.get("dataset_size_limit_envs")
    seed_count = params.get("seed_count")

    # Trigger rebuild if necessary before going into all the threads and processes
    _ = get_restricted_env_id_lists(full=True)

    # Initialize the dataset
    if sim_seed_dataset:
        copy_seed_dataset(from_dataset=sim_seed_dataset,
                          to_dataset=dagger_dataset_name(run_name),
                          seed_count=seed_count or dataset_limit)
        gap = 0
    else:
        # TODO: Refactor this into a prompt function
        data_path = get_dataset_dir(dagger_dataset_name(run_name))
        if os.path.exists(data_path):
            print("DATASET EXISTS! Continue where left off?")
            c = input(" (y/n) >>> ")
            if c != "y":
                raise ValueError(
                    f"Not continuing: Dataset {data_path} exists. Delete it if you like and try again"
                )
        else:
            os.makedirs(data_path, exist_ok=True)
        gap = dataset_limit - len(os.listdir(data_path))

    print("SUPP: Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    # ----------------------------------------------------------------------------------------------------------------
    # Load / initialize model

    model = load_model(setup["model"], setup["model_file"],
                       domain="sim")[0].to(device)
    oracle = load_model("oracle")[0]

    # ----------------------------------------------------------------------------------------------------------------
    # Continue where we left off - load the model and set the iteration/epoch number

    for start_iteration in range(10000):
        epfname = epoch_dag_filename(run_name, start_iteration)
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_iteration > 0:
        print(
            f"DAG: CONTINUING DAGGER TRAINING FROM ITERATION: {start_iteration}"
        )
        load_pytorch_model(model,
                           epoch_dag_filename(run_name, start_iteration - 1))

    # ----------------------------------------------------------------------------------------------------------------
    # Initialize trainer

    trainer = Trainer(model,
                      epoch=start_iteration,
                      name=setup["model"],
                      run_name=setup["run_name"])
    trainer.set_dataset_names([dagger_dataset_name(run_name)])

    # ----------------------------------------------------------------------------------------------------------------
    # Initialize policy roller

    roller = SimpleParallelPolicyRoller(
        num_workers=params["num_workers"],
        device=params["device"],
        policy_name=setup["model"],
        policy_file=setup["model_file"],
        oracle=oracle,
        dataset_save_name=dagger_dataset_name(run_name),
        no_reward=True)
    rollout_sampler = RolloutSampler(roller)

    # ----------------------------------------------------------------------------------------------------------------
    # Train DAgger - loop over iterations; in each, prune, roll out, and train an epoch

    print("SUPP: Beginning training...")
    for iteration in range(start_iteration, num_iterations):
        print(f"DAG: Starting iteration {iteration}")

        # Remove extra rollouts to keep within DAggerFM limit
        prune_dataset(run_name, dataset_limit)

        # Rollout and collect more data for training and evaluation
        policy_state = model.get_policy_state()
        rollout_sampler.sample_n_rollouts(
            n=gap if iteration == 0 else params["train_envs_per_iteration"],
            policy_state=policy_state,
            sample=False,
            envs="train",
            dagger_beta=dagger_beta(params, iteration))

        eval_rollouts = rollout_sampler.sample_n_rollouts(
            n=params["eval_envs_per_iteration"],
            policy_state=policy_state,
            sample=False,
            envs="dev",
            dagger_beta=0)

        # Kill airsim instances so that they don't take up GPU memory and in general slow things down during training
        roller.kill_airsim()

        # Evaluate success / metrics and save to tensorboard
        if setup["eval_nl"]:
            evaler = DataEvalNL(run_name,
                                entire_trajectory=False,
                                save_images=False)
            evaler.evaluate_dataset(eval_rollouts)
            results = evaler.get_results()
            print("Results:", results)
            evaler.write_summaries(setup["run_name"], "dagger_eval", iteration)

        # Do one epoch of supervised training
        print("SUPP: Beginning Epoch")
        train_loss = trainer.train_epoch(train_envs=train_envs, eval=False)
        #test_loss = trainer.train_epoch(env_list_common=dev_envs_common, env_list_sim=dev_envs_sim, eval=True)

        # Save the model to file
        print("SUPP: Epoch", iteration, "train_loss:", train_loss)
        save_pytorch_model(model, epoch_dag_filename(run_name, iteration))
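
`dagger_beta(params, iteration)` controls how often the oracle's action is used in place of the learner's during rollouts; its definition is not included in these examples. A common choice, shown here only as an illustrative assumption (the `beta_decay_iterations` parameter is hypothetical), is a linearly decaying schedule:

def dagger_beta(params, iteration):
    # Illustrative assumption: mix oracle and learner actions with a probability
    # that decays linearly from 1 to 0 over `beta_decay_iterations` iterations.
    decay_iters = params.get("beta_decay_iterations", 10)
    return max(0.0, 1.0 - iteration / decay_iters)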
Example #8
def load_model(model_file_override=None,
               real=False,
               model_name_override=False):

    setup = P.get_current_parameters()["Setup"]
    model_name = model_name_override or setup["model"]
    model_file = model_file_override or setup["model_file"] or None
    cuda = setup["cuda"]
    run_name = setup["run_name"]

    model = None
    pytorch_model = False

    # -----------------------------------------------------------------------------------------------------------------
    # Oracles / baselines that ignore images
    # -----------------------------------------------------------------------------------------------------------------

    if model_name == "oracle":
        rollout_params = get_current_parameters()["Rollout"]
        if rollout_params["oracle_type"] == "SimpleCarrotPlanner":
            model = SimpleCarrotPlanner()
            print("Using simple carrot planner")
        elif rollout_params["oracle_type"] == "BasicCarrotPlanner":
            model = BasicCarrotPlanner()
            print("Using basic carrot planner")
        elif rollout_params["oracle_type"] == "FancyCarrotPlanner":
            model = FancyCarrotPlanner()
            print("Using fancy carrot planner")
        else:
            print("UNKNOWN ORACLE: ", rollout_params["OracleType"])
            exit(-1)
    elif model_name == "baseline_straight":
        model = BaselineStraight()
    elif model_name == "baseline_stop":
        model = BaselineStop()

    # -----------------------------------------------------------------------------------------------------------------
    # FASTER RSS 2018 Resubmission Model
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "gsmn":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jlang":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=False)
        pytorch_model = True
    elif model_name == "gsmn_wo_jgnd":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=False,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jclass":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=False,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jgoal":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=False,
                         aux_lang=True)
        pytorch_model = True

    elif model_name == "gsmn_w_posnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=True,
                         rot_noise=False)
        pytorch_model = True
    elif model_name == "gsmn_w_rotnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=False,
                         rot_noise=True)
        pytorch_model = True
    elif model_name == "gsmn_w_bothnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=True,
                         rot_noise=True)
        pytorch_model = True

    elif model_name == "gs_fpv":
        model = ModelGSFPV(run_name,
                           aux_class_features=True,
                           aux_grounding_features=True,
                           aux_lang=True,
                           recurrence=False)
        pytorch_model = True
    elif model_name == "gs_fpv_mem":
        model = ModelGSFPV(run_name,
                           aux_class_features=True,
                           aux_grounding_features=True,
                           aux_lang=True,
                           recurrence=True)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Model
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "sm_traj_nav_ratio":
        model = ModelTrajectoryProbRatio(run_name, model_class=mtpr.MODEL_FPV)
        pytorch_model = True
    elif model_name == "sm_traj_nav_ratio_path":
        model = ModelTrajectoryProbRatio(run_name,
                                         model_class=mtpr.PVN_STAGE1_ONLY)
        pytorch_model = True

    elif model_name == "action_gtr":
        model = ModelTrajectoryToAction(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Refactored
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "pvn_full":
        model = ModelTrajectoryProbRatio(run_name, model_class=mtpr.MODEL_FPV)
        pytorch_model = True
    elif model_name == "pvn_stage1":
        model = ModelTrajectoryProbRatio(run_name,
                                         model_class=mtpr.PVN_STAGE1_ONLY)
        pytorch_model = True
    elif model_name == "pvn_stage2":
        model = ModelTrajectoryToAction(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Top-Down Full Observability Models
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "top_down_goal_batched":
        model = ModelTopDownPathGoalPredictorBatched(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL Baselines
    # -----------------------------------------------------------------------------------------------------------------
    elif model_name == "chaplot":
        model = ModelChaplot(run_name)
        pytorch_model = True

    elif model_name == "misra2017":
        model = ModelMisra2017(run_name)
        pytorch_model = True

    model_loaded = False
    if pytorch_model:
        n_params = get_n_params(model)
        n_params_tr = get_n_trainable_params(model)
        print("Loaded PyTorch model!")
        print("Number of model parameters: " + str(n_params))
        print("Trainable model parameters: " + str(n_params_tr))
        model.init_weights()
        model.eval()
        if model_file:
            load_pytorch_model(model, model_file)
            print("Loaded previous model: ", model_file)
            model_loaded = True
        if cuda:
            model = model.cuda()

    return model, model_loaded
Example #9
def load_model(model_name_override=False,
               model_file_override=None,
               domain="sim"):

    setup = P.get_current_parameters()["Setup"]
    model_name = model_name_override or setup["model"]
    model_file = model_file_override or setup["model_file"] or None
    # TODO: Move this stuff elsewhere and tidy up the model
    perception_model_file = setup.get("perception_model_file") or None
    perception_model_real = setup.get("perception_model_real") or None
    cuda = setup["cuda"]
    run_name = setup["run_name"]

    model = None
    pytorch_model = False

    # -----------------------------------------------------------------------------------------------------------------
    # Oracles / baselines that ignore images
    # -----------------------------------------------------------------------------------------------------------------

    if model_name == "oracle":
        rollout_params = get_current_parameters()["Rollout"]
        if rollout_params["oracle_type"] == "SimpleCarrotPlanner":
            model = SimpleCarrotPlanner()
            print("Using simple carrot planner")
        elif rollout_params["oracle_type"] == "BasicCarrotPlanner":
            model = BasicCarrotPlanner()
            print("Using basic carrot planner")
        elif rollout_params["oracle_type"] == "FancyCarrotPlanner":
            model = FancyCarrotPlanner()
            print("Using fancy carrot planner")
        else:
            print("UNKNOWN ORACLE: ", rollout_params["OracleType"])
            exit(-1)
    elif model_name == "average":
        model = BaselineAverage()
    elif model_name == "stop":
        model = BaselineStop()

    # -----------------------------------------------------------------------------------------------------------------
    # FASTER RSS 2018 Resubmission Model
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "gsmn":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jlang":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=False)
        pytorch_model = True
    elif model_name == "gsmn_wo_jgnd":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=False,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jclass":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=False,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jgoal":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=False,
                         aux_lang=True)
        pytorch_model = True

    elif model_name == "gsmn_w_posnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=True,
                         rot_noise=False)
        pytorch_model = True
    elif model_name == "gsmn_w_rotnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=False,
                         rot_noise=True)
        pytorch_model = True
    elif model_name == "gsmn_w_bothnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=True,
                         rot_noise=True)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # RSS Baselines
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "gs_fpv":
        model = ModelGSFPV(run_name,
                           aux_class_features=True,
                           aux_grounding_features=True,
                           aux_lang=True,
                           recurrence=False)
        pytorch_model = True
    elif model_name == "gs_fpv_mem":
        model = ModelGSFPV(run_name,
                           aux_class_features=True,
                           aux_grounding_features=True,
                           aux_lang=True,
                           recurrence=True)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # RSS Model for Cage
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "gsmn_cage":
        model = ModelRSS(run_name,
                         model_class=msrg.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=False)
        pytorch_model = True

    elif model_name == "gsmn_bidomain":
        model = ModelGSMNBiDomain(run_name, model_instance_name=domain)
        pytorch_model = True

    elif model_name == "gsmn_critic":
        model = ModelGsmnCritic(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Model
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "sm_traj_nav_ratio":
        model = ModelTrajectoryProbRatio(run_name, model_class=mtpr.MODEL_FPV)
        pytorch_model = True
    elif model_name == "sm_traj_nav_ratio_path":
        model = ModelTrajectoryProbRatio(run_name,
                                         model_class=mtpr.PVN_STAGE1_ONLY)
        pytorch_model = True

    elif model_name == "action_gtr":
        model = ModelTrajectoryToAction(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Refactored
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "pvn_full":
        model = ModelTrajectoryProbRatio(run_name, model_class=mtpr.MODEL_FPV)
        pytorch_model = True
    elif model_name == "pvn_stage1":
        model = ModelTrajectoryProbRatio(run_name,
                                         model_class=mtpr.PVN_STAGE1_ONLY)
        pytorch_model = True
    elif model_name == "pvn_stage2":
        model = ModelTrajectoryToAction(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Top-Down Full Observability Models
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "top_down_goal_batched":
        model = ModelTopDownPathGoalPredictorBatched(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL Model for cage (bidomain)
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "pvn_original_stage1_bidomain":
        model = PVN_Stage1_Bidomain_Original(run_name, domain=domain)
        pytorch_model = True

    elif model_name == "pvn_stage1_bidomain":
        model = PVN_Stage1_Bidomain(run_name, domain=domain)
        pytorch_model = True

    elif model_name == "pvn_stage2_bidomain":
        model = PVN_Stage2_Bidomain(run_name, model_instance_name=domain)
        pytorch_model = True

    elif model_name == "pvn_stage2_actor_critic":
        model = PVN_Stage2_ActorCritic(run_name, model_instance_name=domain)
        pytorch_model = True

    elif model_name == "pvn_stage1_critic":
        model = PVN_Stage1_Critic(run_name)
        pytorch_model = True

    elif model_name == "pvn_stage1_critic_big":
        model = PVN_Stage1_Critic_Big(run_name)
        pytorch_model = True

    elif model_name == "pvn_full_bidomain":
        model = PVN_Wrapper_Bidomain(run_name,
                                     model_instance_name=domain,
                                     oracle_stage1=False)
        pytorch_model = True

    elif model_name == "pvn_full_bidomain_ground_truth":
        model = PVN_Wrapper_Bidomain(run_name,
                                     model_instance_name=domain,
                                     oracle_stage1=True)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------

    model_loaded = False
    if pytorch_model:
        n_params = get_n_params(model)
        n_params_tr = get_n_trainable_params(model)
        print("Loaded PyTorch model!")
        print("Number of model parameters: " + str(n_params))
        print("Trainable model parameters: " + str(n_params_tr))
        model.init_weights()
        model.eval()
        if model_file:
            load_pytorch_model(model, model_file, pytorch3to4=True)
            print("Loaded previous model: ", model_file)
            model_loaded = True

        if cuda:
            model = model.cuda()

    return model, model_loaded