Example #1
def train_supervised():
    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    supervised_params = get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]

    model, model_loaded = load_model()

    print("Loading data")
    train_envs, dev_envs, test_envs = get_all_env_id_lists(max_envs=setup["max_envs"])

    if "split_train_data" in supervised_params and supervised_params["split_train_data"]:
        split_name = supervised_params["train_data_split"]
        split = load_env_split()[split_name]
        train_envs = [env_id for env_id in train_envs if env_id in split]
        print("Using " + str(len(train_envs)) + " envs from dataset split: " + split_name)

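    # Checkpoints are named "supervised_<model>_<run_name>"; if start_epoch > 0, training resumes from the matching per-epoch snapshot under tmp/.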
    filename = "supervised_" + setup["model"] + "_" + setup["run_name"]
    start_filename = "tmp/" + filename + "_epoch_" + str(supervised_params["start_epoch"])
    if supervised_params["start_epoch"] > 0:
        if file_exists(start_filename):
            load_pytorch_model(model, start_filename)
        else:
            print("Couldn't continue training. Model file doesn't exist at:")
            print(start_filename)
            exit(-1)

    if setup["restore_weights_name"]:
        restore_pretrained_weights(model, setup["restore_weights_name"], setup["fix_restored_weights"])

    trainer = Trainer(model, epoch=supervised_params["start_epoch"], name=setup["model"], run_name=setup["run_name"])

    print("Beginning training...")
    best_test_loss = 1000
    for epoch in range(num_epochs):
        train_loss = trainer.train_epoch(train_data=None, train_envs=train_envs, eval=False)

        trainer.model.correct_goals = 0
        trainer.model.total_goals = 0

        test_loss = trainer.train_epoch(train_data=None, train_envs=dev_envs, eval=True)

        print("GOALS: ", trainer.model.correct_goals, trainer.model.total_goals)

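        # Keep the best model so far (by dev loss) under the base filename, and also save a per-epoch snapshot under tmp/.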
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            save_pytorch_model(trainer.model, filename)
            print("Saved model in:", filename)
        print("Epoch", epoch, "train_loss:", train_loss, "test_loss:", test_loss)
        save_pytorch_model(trainer.model, "tmp/" + filename + "_epoch_" + str(epoch))
        if hasattr(trainer.model, "save"):
            trainer.model.save(epoch)
        save_pretrained_weights(trainer.model, setup["run_name"])
Example #2
File: train_rl.py  Project: pianpwk/drif
def train_rl():
    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    params = get_current_parameters()["RL"]

    print("Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    filename = "rl_" + setup["model"] + "_" + setup["run_name"]

    trainer = TrainerRL(params=dict_merge(setup, params))

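    # Find the first epoch whose checkpoint file is missing; training will resume from that epoch.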
    for start_epoch in range(10000):
        epfname = epoch_filename(filename, start_epoch)
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break

    if start_epoch > 0:
        print(f"CONTINUING RL TRAINING FROM EPOCH: {start_epoch}")
        load_pytorch_model(trainer.full_model,
                           epoch_filename(filename, start_epoch - 1))
        trainer.set_start_epoch(start_epoch)

    print("Beginning training...")
    best_dev_reward = -1e+10
    for epoch in range(start_epoch, 10000):
        train_reward, metrics = trainer.train_epoch(eval=False, envs="train")
        # TODO: Test on just a few dev environments
        # TODO: Take most likely or mean action when testing
        dev_reward, metrics = trainer.train_epoch(eval=True, envs="dev")
        #dev_reward, metrics = trainer.train_epoch(eval=True, envs="dev")
        dev_reward = 0

        #if dev_reward >= best_dev_reward:
        #    best_dev_reward = dev_reward
        #    save_pytorch_model(trainer.full_model, filename)
        #    print("Saved model in:", filename)

        print("Epoch", epoch, "train reward:", train_reward, "dev reward:",
              dev_reward)
        save_pytorch_model(trainer.full_model, epoch_filename(filename, epoch))
        if hasattr(trainer.full_model, "save"):
            trainer.full_model.save(epoch)
Example #3
def train_top_down_pred(args, max_epoch=SUPERVISED_EPOCHS):
    initialize_experiment(args.run_name, args.setup_name)

    model, model_loaded = load_model()

    # TODO: Get batch size from global parameter server when it exists
    batch_size = 1 if args.model in [
        "top_down", "top_down_prior", "top_down_sm", "top_down_pretrain",
        "top_down_goal_pretrain", "top_down_nav", "top_down_cond"
    ] else BATCH_SIZE

    lr = 0.001  # * batch_size
    trainer = Trainer(model,
                      epoch=args.start_epoch,
                      name=args.model,
                      run_name=args.run_name)

    train_envs, dev_envs, test_envs = get_all_env_id_lists(
        max_envs=args.max_envs)

    filename = "top_down_" + args.model + "_" + args.run_name

    if args.restore_weights_name is not None:
        restore_pretrained_weights(model, args.restore_weights_name,
                                   args.fix_restored_weights)

    print("Beginning training...")
    best_test_loss = 1000

    validation_loss = []

    for epoch in range(SUPERVISED_EPOCHS):
        train_loss = -1

        if not args.eval_pretrain:
            train_loss = trainer.train_epoch(train_envs=train_envs, eval=False)

        test_loss = trainer.train_epoch(train_envs=dev_envs, eval=True)
        validation_loss.append([epoch, test_loss])

        if not args.eval_pretrain:
            if test_loss < best_test_loss:
                best_test_loss = test_loss
                save_pytorch_model(trainer.model, filename)
                print("Saved model in:", filename)

            print("Epoch", epoch, "train_loss:", train_loss, "test_loss:",
                  test_loss)
            save_pytorch_model(trainer.model,
                               "tmp/" + filename + "_epoch_" + str(epoch))
            save_pretrained_weights(trainer.model, args.run_name)

        else:
            break

        if max_epoch is not None and epoch > max_epoch:
            print("Reached epoch limit!")
            break

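    # Write the per-epoch dev losses to a CSV file under the model directory.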
    test_loss_dir = get_model_dir() + "/test_loss/" + filename + "_test_loss.csv"
    validation_loss = pd.DataFrame(validation_loss,
                                   columns=['epoch', "test_loss"])
    validation_loss.to_csv(test_loss_dir, index=False)
Example #4
def save(self, epoch):
    filename = self.params["map_to_action_file"] + "_" + self.run_name + "_" + str(epoch)
    save_pytorch_model(self.map_to_action, filename)
    print("Saved action model to " + filename)
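All of these examples rely on the project's save_pytorch_model / load_pytorch_model helpers together with get_model_dir(). A minimal sketch of what such helpers might look like, assuming the common torch.save / torch.load state-dict convention and the ".pytorch" extension seen in the examples (the actual drif implementation may differ):

import os
import torch

def save_pytorch_model(model, name):
    # Assumption: serialize the model's state_dict under get_model_dir(),
    # appending the ".pytorch" extension used throughout these examples.
    path = os.path.join(get_model_dir(), name + ".pytorch")
    os.makedirs(os.path.dirname(path), exist_ok=True)  # e.g. for "tmp/..." names
    torch.save(model.state_dict(), path)

def load_pytorch_model(model, name):
    # Assumption: the extension is appended here too, which is why the caller
    # in Example #7 passes the path without ".pytorch".
    path = os.path.join(get_model_dir(), name + ".pytorch")
    model.load_state_dict(torch.load(path, map_location="cpu"))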
Example #5
def train_supervised_worker(rl_process_conn):
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    rlsup = P.get_current_parameters()["RLSUP"]
    setup["trajectory_length"] = setup["sup_trajectory_length"]
    run_name = setup["run_name"]
    supervised_params = P.get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]
    sup_device = rlsup.get("sup_device", "cuda:1")

    model_oracle_critic = None

    print("SUPP: Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    # Load the starter model and save it at epoch 0
    # Supervised worker to use GPU 1, RL will use GPU 0. Simulators run on GPU 2
    model_sim = load_model(setup["sup_model"],
                           setup["sim_model_file"],
                           domain="sim")[0].to(sup_device)
    model_real = load_model(setup["sup_model"],
                            setup["real_model_file"],
                            domain="real")[0].to(sup_device)
    model_critic = load_model(setup["sup_critic_model"],
                              setup["critic_model_file"])[0].to(sup_device)

    # ----------------------------------------------------------------------------------------------------------------

    print("SUPP: Initializing trainer")
    rlsup_params = P.get_current_parameters()["RLSUP"]
    sim_seed_dataset = rlsup_params.get("sim_seed_dataset")

    # TODO: Figure out whether the env-id cutoff should be 6000 or 7000 here
    trainer = TrainerBidomainBidata(model_real,
                                    model_sim,
                                    model_critic,
                                    model_oracle_critic,
                                    epoch=0)
    train_envs_common = [e for e in train_envs if 6000 <= e < 7000]
    train_envs_sim = [e for e in train_envs if e < 7000]
    dev_envs_common = [e for e in dev_envs if 6000 <= e < 7000]
    dev_envs_sim = [e for e in dev_envs if e < 7000]
    sim_datasets = [rl_dataset_name(run_name)]
    real_datasets = ["real"]
    trainer.set_dataset_names(sim_datasets=sim_datasets,
                              real_datasets=real_datasets)

    # ----------------------------------------------------------------------------------------------------------------
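    # Resume support: find the first sup epoch whose sim stage1 checkpoint is missing, then reload all three models from the epoch before it.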
    for start_sup_epoch in range(10000):
        epfname = epoch_sup_filename(run_name,
                                     start_sup_epoch,
                                     model="stage1",
                                     domain="sim")
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_sup_epoch > 0:
        print(f"SUPP: CONTINUING SUP TRAINING FROM EPOCH: {start_sup_epoch}")
        load_pytorch_model(
            model_real,
            epoch_sup_filename(run_name,
                               start_sup_epoch - 1,
                               model="stage1",
                               domain="real"))
        load_pytorch_model(
            model_sim,
            epoch_sup_filename(run_name,
                               start_sup_epoch - 1,
                               model="stage1",
                               domain="sim"))
        load_pytorch_model(
            model_critic,
            epoch_sup_filename(run_name,
                               start_sup_epoch - 1,
                               model="critic",
                               domain="critic"))
        trainer.set_start_epoch(start_sup_epoch)

    # ----------------------------------------------------------------------------------------------------------------
    print("SUPP: Beginning training...")
    for epoch in range(start_sup_epoch, num_epochs):
        # Tell the RL process that a new Stage 1 model is ready for loading
        print("SUPP: Sending model to RL")
        model_sim.reset()
        rl_process_conn.send(
            ["stage1_model_state_dict",
             model_sim.state_dict()])
        if DEBUG_RL:
            while True:
                sleep(1)

        if not sim_seed_dataset:
            ddir = get_dataset_dir(rl_dataset_name(run_name))
            os.makedirs(ddir, exist_ok=True)
            while len(os.listdir(ddir)) < 20:
                print("SUPP: Waiting for rollouts to appear")
                sleep(3)

        print("SUPP: Beginning Epoch")
        train_loss = trainer.train_epoch(env_list_common=train_envs_common,
                                         env_list_sim=train_envs_sim,
                                         eval=False)
        test_loss = trainer.train_epoch(env_list_common=dev_envs_common,
                                        env_list_sim=dev_envs_sim,
                                        eval=True)
        print("SUPP: Epoch", epoch, "train_loss:", train_loss, "test_loss:",
              test_loss)
        save_pytorch_model(
            model_real,
            epoch_sup_filename(run_name, epoch, model="stage1", domain="real"))
        save_pytorch_model(
            model_sim,
            epoch_sup_filename(run_name, epoch, model="stage1", domain="sim"))
        save_pytorch_model(
            model_critic,
            epoch_sup_filename(run_name,
                               epoch,
                               model="critic",
                               domain="critic"))
Example #6
def train_rl_worker(sup_process_conn):
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    setup["trajectory_length"] = setup["rl_trajectory_length"]
    run_name = setup["run_name"]
    rlsup = P.get_current_parameters()["RLSUP"]
    params = P.get_current_parameters()["RL"]
    num_rl_epochs = params["num_epochs"]
    # These need to be distinguished between supervised and RL because supervised trains on ALL envs, RL only on 6000-7000
    setup["env_range_start"] = setup["rl_env_range_start"]
    setup["env_range_end"] = setup["rl_env_range_end"]
    rl_device = rlsup.get("rl_device", "cuda:0")

    trainer = TrainerRL(params=dict_merge(setup, params),
                        save_rollouts_to_dataset=rl_dataset_name(run_name),
                        device=rl_device)

    # -------------------------------------------------------------------------------------
    # TODO: Continue training (including figuring out how to initialize the Supervised Stage 1 real/sim/critic models and the RL Stage 2 policy)
    start_rl_epoch = 0
    for start_rl_epoch in range(num_rl_epochs):
        epfname = epoch_rl_filename(run_name, start_rl_epoch, model="full")
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_rl_epoch > 0:
        print(f"RLP: CONTINUING RL TRAINING FROM EPOCH: {start_rl_epoch}")
        load_pytorch_model(
            trainer.full_model,
            epoch_rl_filename(run_name, start_rl_epoch - 1, model="full"))
        trainer.set_start_epoch(start_rl_epoch)
    # Wait for the supervised process to send its model
    sleep(2)

    # -------------------------------------------------------------------------------------

    print("RLP: Beginning training...")
    for rl_epoch in range(start_rl_epoch, num_rl_epochs):
        # Get the latest Stage 1 model. Halt on the first epoch so that we can actually initialize the Stage 1
        new_stage1_model_state_dict = receive_stage1_state(
            sup_process_conn, halt=(rl_epoch == start_rl_epoch))
        if new_stage1_model_state_dict:
            print(f"RLP: Re-loading latest Stage 1 model")
            trainer.reload_stage1(new_stage1_model_state_dict)

        train_reward, metrics = trainer.train_epoch(epoch_num=rl_epoch,
                                                    eval=False,
                                                    envs="train")
        dev_reward, metrics = trainer.train_epoch(epoch_num=rl_epoch,
                                                  eval=True,
                                                  envs="dev")

        print("RLP: RL Epoch", rl_epoch, "train reward:", train_reward,
              "dev reward:", dev_reward)
        save_pytorch_model(trainer.full_model,
                           epoch_rl_filename(run_name, rl_epoch, model="full"))
        save_pytorch_model(
            trainer.full_model.stage1_visitation_prediction,
            epoch_rl_filename(run_name, rl_epoch, model="stage1"))
        save_pytorch_model(
            trainer.full_model.stage2_action_generation,
            epoch_rl_filename(run_name, rl_epoch, model="stage2"))
Example #7
def train_supervised():
    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    supervised_params = get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]

    model, model_loaded = load_model()
    # import pdb; pdb.set_trace()
    # import pickle
    # with open('/storage/dxsun/model_input.pickle', 'rb') as f: data = pickle.load(f)
    # g = model(data['images'], data['states'], data['instructions'], data['instr_lengths'], data['has_obs'], data['plan'], data['save_maps_only'], data['pos_enc'], data['noisy_poses'], data['start_poses'], data['firstseg'])
    print("model:", model)
    print("model type:", type(model))
    print("Loading data")
    # import pdb;pdb.set_trace()
    train_envs, dev_envs, test_envs = get_all_env_id_lists(
        max_envs=setup["max_envs"])
    if "split_train_data" in supervised_params and supervised_params["split_train_data"]:
        split_name = supervised_params["train_data_split"]
        split = load_env_split()[split_name]
        train_envs = [env_id for env_id in train_envs if env_id in split]
        print("Using " + str(len(train_envs)) + " envs from dataset split: " +
              split_name)

    filename = "supervised_" + setup["model"] + "_" + setup["run_name"]

    # Code looks weird here because load_pytorch_model adds ".pytorch" to end of path, but
    # file_exists doesn't
    model_path = "tmp/" + filename + "_epoch_" + str(
        supervised_params["start_epoch"])
    model_path_with_extension = model_path + ".pytorch"
    print("model path:", model_path_with_extension)
    if supervised_params["start_epoch"] > 0:
        if file_exists(model_path_with_extension):
            print("THE FILE EXISTS code1")
            load_pytorch_model(model, model_path)
        else:
            print("Couldn't continue training. Model file doesn't exist at:")
            print(model_path_with_extension)
            exit(-1)
    # import pdb;pdb.set_trace()
    ## If you just want to use the pretrained model
    # load_pytorch_model(model, "supervised_pvn_stage1_train_corl_pvn_stage1")

    # all_train_data, all_test_data = data_io.train_data.load_supervised_data(max_envs=100)
    if setup["restore_weights_name"]:
        restore_pretrained_weights(model, setup["restore_weights_name"],
                                   setup["fix_restored_weights"])

    # Add a tensorboard logger to the model and trainer
    tensorboard_dir = get_current_parameters()["Environment"]["tensorboard_dir"]
    logger = Logger(tensorboard_dir)
    model.logger = logger
    if hasattr(model, "goal_good_criterion"):
        print("gave logger to goal evaluator")
        model.goal_good_criterion.logger = logger

    trainer = Trainer(model,
                      epoch=supervised_params["start_epoch"],
                      name=setup["model"],
                      run_name=setup["run_name"])

    trainer.logger = logger

    # import pdb;pdb.set_trace()
    print("Beginning training...")
    best_test_loss = 1000

    continue_epoch = (supervised_params["start_epoch"] + 1
                      if supervised_params["start_epoch"] > 0 else 0)
    rng = range(0, num_epochs)
    print("filename:", filename)

    # import pdb; pdb.set_trace()

    for epoch in rng:
        # import pdb;pdb.set_trace()
        train_loss = trainer.train_epoch(train_data=None,
                                         train_envs=train_envs,
                                         eval=False)
        # train_loss = trainer.train_epoch(train_data=all_train_data, train_envs=train_envs, eval=False)

        trainer.model.correct_goals = 0
        trainer.model.total_goals = 0

        test_loss = trainer.train_epoch(train_data=None,
                                        train_envs=dev_envs,
                                        eval=True)

        print("GOALS: ", trainer.model.correct_goals,
              trainer.model.total_goals)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            save_pytorch_model(trainer.model, filename)
            print("Saved model in:", filename)
        print("Epoch", epoch, "train_loss:", train_loss, "test_loss:",
              test_loss)
        save_pytorch_model(trainer.model,
                           "tmp/" + filename + "_epoch_" + str(epoch))
        if hasattr(trainer.model, "save"):
            trainer.model.save(epoch)
        save_pretrained_weights(trainer.model, setup["run_name"])
Example #8
def train_dagger_simple():
    # ----------------------------------------------------------------------------------------------------------------
    # Load params and configure stuff

    P.initialize_experiment()
    params = P.get_current_parameters()["SimpleDagger"]
    setup = P.get_current_parameters()["Setup"]
    num_iterations = params["num_iterations"]
    sim_seed_dataset = params.get("sim_seed_dataset")
    run_name = setup["run_name"]
    device = params.get("device", "cuda:1")
    dataset_limit = params.get("dataset_size_limit_envs")
    seed_count = params.get("seed_count")

    # Trigger rebuild if necessary before going into all the threads and processes
    _ = get_restricted_env_id_lists(full=True)

    # Initialize the dataset
    if sim_seed_dataset:
        copy_seed_dataset(from_dataset=sim_seed_dataset,
                          to_dataset=dagger_dataset_name(run_name),
                          seed_count=seed_count or dataset_limit)
        gap = 0
    else:
        # TODO: Refactor this into a prompt function
        data_path = get_dataset_dir(dagger_dataset_name(run_name))
        if os.path.exists(data_path):
            print("DATASET EXISTS! Continue where left off?")
            c = input(" (y/n) >>> ")
            if c != "y":
                raise ValueError(
                    f"Not continuing: Dataset {data_path} exists. Delete it if you like and try again"
                )
        else:
            os.makedirs(data_path, exist_ok=True)
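        # Number of rollouts still needed to fill the dataset up to dataset_limit (collected on the first DAgger iteration).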
        gap = dataset_limit - len(os.listdir(data_path))

    print("SUPP: Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    # ----------------------------------------------------------------------------------------------------------------
    # Load / initialize model

    model = load_model(setup["model"], setup["model_file"],
                       domain="sim")[0].to(device)
    oracle = load_model("oracle")[0]

    # ----------------------------------------------------------------------------------------------------------------
    # Continue where we left off - load the model and set the iteration/epoch number

    for start_iteration in range(10000):
        epfname = epoch_dag_filename(run_name, start_iteration)
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_iteration > 0:
        print(
            f"DAG: CONTINUING DAGGER TRAINING FROM ITERATION: {start_iteration}"
        )
        load_pytorch_model(model,
                           epoch_dag_filename(run_name, start_iteration - 1))

    # ----------------------------------------------------------------------------------------------------------------
    # Initialize trainer

    trainer = Trainer(model,
                      epoch=start_iteration,
                      name=setup["model"],
                      run_name=setup["run_name"])
    trainer.set_dataset_names([dagger_dataset_name(run_name)])

    # ----------------------------------------------------------------------------------------------------------------
    # Initialize policy roller

    roller = SimpleParallelPolicyRoller(
        num_workers=params["num_workers"],
        device=params["device"],
        policy_name=setup["model"],
        policy_file=setup["model_file"],
        oracle=oracle,
        dataset_save_name=dagger_dataset_name(run_name),
        no_reward=True)
    rollout_sampler = RolloutSampler(roller)

    # ----------------------------------------------------------------------------------------------------------------
    # Train DAgger - loop over iterations; in each, prune the dataset, roll out, and train an epoch

    print("SUPP: Beginning training...")
    for iteration in range(start_iteration, num_iterations):
        print(f"DAG: Starting iteration {iteration}")

        # Remove extra rollouts to keep within DAggerFM limit
        prune_dataset(run_name, dataset_limit)

        # Rollout and collect more data for training and evaluation
        policy_state = model.get_policy_state()
        rollout_sampler.sample_n_rollouts(
            n=gap if iteration == 0 else params["train_envs_per_iteration"],
            policy_state=policy_state,
            sample=False,
            envs="train",
            dagger_beta=dagger_beta(params, iteration))

        eval_rollouts = rollout_sampler.sample_n_rollouts(
            n=params["eval_envs_per_iteration"],
            policy_state=policy_state,
            sample=False,
            envs="dev",
            dagger_beta=0)

        # Kill airsim instances so that they don't take up GPU memory and in general slow things down during training
        roller.kill_airsim()

        # Evaluate success / metrics and save to tensorboard
        if setup["eval_nl"]:
            evaler = DataEvalNL(run_name,
                                entire_trajectory=False,
                                save_images=False)
            evaler.evaluate_dataset(eval_rollouts)
            results = evaler.get_results()
            print("Results:", results)
            evaler.write_summaries(setup["run_name"], "dagger_eval", iteration)

        # Do one epoch of supervised training
        print("SUPP: Beginning Epoch")
        train_loss = trainer.train_epoch(train_envs=train_envs, eval=False)
        #test_loss = trainer.train_epoch(env_list_common=dev_envs_common, env_list_sim=dev_envs_sim, eval=True)

        # Save the model to file
        print("SUPP: Epoch", iteration, "train_loss:", train_loss)
        save_pytorch_model(model, epoch_dag_filename(run_name, iteration))
Example #9
def train_supervised_bidomain():
    P.initialize_experiment()

    setup = P.get_current_parameters()["Setup"]
    supervised_params = P.get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]

    model_sim, _ = load_model(setup["model"], setup["sim_model_file"], domain="sim")
    model_real, _ = load_model(setup["model"], setup["real_model_file"], domain="real")
    model_critic, _ = load_model(setup["critic_model"], setup["critic_model_file"])

    if P.get_current_parameters()["Training"].get("use_oracle_critic", False):
        model_oracle_critic, _ = load_model(setup["critic_model"], setup["critic_model_file"])
        # This changes the name in the summary writer to get a different color plot
        oname = model_oracle_critic.model_name
        model_oracle_critic.set_model_name(oname + "_oracle")
        model_oracle_critic.model_name = oname
    else:
        model_oracle_critic = None

    print("Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    real_filename = f"supervised_{setup['model']}_{setup['run_name']}_real"
    sim_filename  = f"supervised_{setup['model']}_{setup['run_name']}_sim"
    critic_filename = f"supervised_{setup['critic_model']}_{setup['run_name']}_critic"

    # TODO: (Maybe) Implement continuing of training

    # Bidata means that we treat Lani++ and LaniOriginal examples differently, only computing domain-adversarial stuff on Lani++
    bidata = P.get_current_parameters()["Training"].get("bidata", False)
    if bidata == "v2":
        trainer = TrainerBidomainBidata(model_real, model_sim, model_critic, model_oracle_critic, epoch=0)
        train_envs_common = [e for e in train_envs if 6000 <= e < 7000]
        train_envs_sim = train_envs
        dev_envs_common = [e for e in dev_envs if 6000 <= e < 7000]
        dev_envs_sim = dev_envs
    elif bidata:
        trainer = TrainerBidomainBidata(model_real, model_sim, model_critic, model_oracle_critic, epoch=0)
        train_envs_common = [e for e in train_envs if 6000 <= e < 7000]
        train_envs_sim = [e for e in train_envs if e < 6000]
        dev_envs_common = [e for e in dev_envs if 6000 <= e < 7000]
        dev_envs_sim = [e for e in dev_envs if e < 6000]
    else:
        trainer = TrainerBidomain(model_real, model_sim, model_critic, model_oracle_critic, epoch=0)

    print("Beginning training...")
    best_test_loss = 1000
    for epoch in range(num_epochs):
        if bidata:
            train_loss = trainer.train_epoch(env_list_common=train_envs_common, env_list_sim=train_envs_sim, eval=False)
            test_loss = trainer.train_epoch(env_list_common=dev_envs_common, env_list_sim=dev_envs_sim, eval=True)
        else:
            train_loss = trainer.train_epoch(env_list=train_envs, eval=False)
            test_loss = trainer.train_epoch(env_list=dev_envs, eval=True)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            save_pytorch_model(model_real, real_filename)
            save_pytorch_model(model_sim, sim_filename)
            save_pytorch_model(model_critic, critic_filename)
            print(f"Saved models in: \n Real: {real_filename} \n Sim: {sim_filename} \n Critic: {critic_filename}")

        print("Epoch", epoch, "train_loss:", train_loss, "test_loss:", test_loss)
        save_pytorch_model(model_real, f"tmp/{real_filename}_epoch_{epoch}")
        save_pytorch_model(model_sim, f"tmp/{sim_filename}_epoch_{epoch}")
        save_pytorch_model(model_critic, f"tmp/{critic_filename}_epoch_{epoch}")