Example #1
    def __init__(self, run_name="",
                 aux_class_features=False, aux_grounding_features=False, aux_lang=False, recurrence=False):

        super(ModelGSFPV, self).__init__()
        self.model_name = "gs_fpv" + ("_mem" if recurrence else "")
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["Model"]
        self.aux_weights = get_current_parameters()["AuxWeights"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # Auxiliary Objectives
        self.use_aux_class_features = aux_class_features
        self.use_aux_grounding_features = aux_grounding_features
        self.use_aux_lang = aux_lang
        self.use_recurrence = recurrence

        self.img_to_features_w = FPVToFPVMap(self.params["img_w"], self.params["img_h"],
                                             self.params["resnet_channels"], self.params["feature_channels"])

        self.lang_filter_gnd = MapLangSemanticFilter(self.params["emb_size"], self.params["feature_channels"], self.params["relevance_channels"])

        self.lang_filter_goal = MapLangSpatialFilter(self.params["emb_size"], self.params["relevance_channels"], self.params["goal_channels"])

        self.map_downsample = DownsampleResidual(self.params["map_to_act_channels"], 2)

        self.recurrence = RecurrentEmbedding(self.params["gs_fpv_feature_map_size"], self.params["gs_fpv_recurrence_size"])

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"], self.params["emb_size"], self.params["emb_layers"])

        in_features_size = self.params["gs_fpv_feature_map_size"] + self.params["emb_size"]
        if self.use_recurrence:
            in_features_size += self.params["gs_fpv_recurrence_size"]

        self.features_to_action = DenseMlpBlock2(in_features_size, self.params["mlp_hidden"], 4)

        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        self.add_auxiliary(ClassAuxiliary2D("aux_class", None,  self.params["feature_channels"], self.params["num_landmarks"],
                                                "fpv_features", "lm_pos_fpv", "lm_indices"))
        self.add_auxiliary(ClassAuxiliary2D("aux_ground", None, self.params["relevance_channels"], 2,
                                                "fpv_features_g", "lm_pos_fpv", "lm_mentioned"))
        if self.params["templates"]:
            self.add_auxiliary(ClassAuxiliary("aux_lang_lm", self.params["emb_size"], self.params["num_landmarks"], 1,
                                              "sentence_embed", "lm_mentioned_tplt"))
            self.add_auxiliary(ClassAuxiliary("aux_lang_side", self.params["emb_size"], self.params["num_sides"], 1,
                                              "sentence_embed", "side_mentioned_tplt"))
        else:
            self.add_auxiliary(ClassAuxiliary("aux_lang_lm_nl", self.params["emb_size"], 2, self.params["num_landmarks"],
                                                "sentence_embed", "lang_lm_mentioned"))

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0
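
Note the parentheses on the model_name line above: Python's conditional expression binds more loosely than "+", so without them the whole expression would evaluate to "" whenever recurrence is False. A two-line check:

recurrence = False
without_parens = "gs_fpv" + "_mem" if recurrence else ""   # -> "" (the ternary applies to the whole concatenation)
with_parens = "gs_fpv" + ("_mem" if recurrence else "")    # -> "gs_fpv"
print(repr(without_parens), repr(with_parens))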
Example #2
    def __init__(self, run_name="", model_class=MODEL_RSS,
                 aux_class_features=False, aux_grounding_features=False,
                 aux_class_map=False, aux_grounding_map=False, aux_goal_map=False,
                 aux_lang=False, aux_traj=False, rot_noise=False, pos_noise=False):

        super(ModelTrajectoryTopDown, self).__init__()
        self.model_name = "sm_trajectory" + str(model_class)
        self.model_class = model_class
        print("Init model of type: ", str(model_class))
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["Model"]
        self.aux_weights = get_current_parameters()["AuxWeights"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # Auxiliary Objectives
        self.use_aux_class_features = aux_class_features
        self.use_aux_grounding_features = aux_grounding_features
        self.use_aux_class_on_map = aux_class_map
        self.use_aux_grounding_on_map = aux_grounding_map
        self.use_aux_goal_on_map = aux_goal_map
        self.use_aux_lang = aux_lang
        self.use_aux_traj_on_map = aux_traj
        self.use_aux_reg_map = self.aux_weights["regularize_map"]

        self.use_rot_noise = rot_noise
        self.use_pos_noise = pos_noise


        # Path-pred FPV model definition
        # --------------------------------------------------------------------------------------------------------------

        self.img_to_features_w = FPVToGlobalMap(
            source_map_size=self.params["global_map_size"], world_size_px=self.params["world_size_px"], world_size=self.params["world_size_m"],
            res_channels=self.params["resnet_channels"], map_channels=self.params["feature_channels"],
            img_w=self.params["img_w"], img_h=self.params["img_h"], img_dbg=IMG_DBG)

        self.map_accumulator_w = LeakyIntegratorGlobalMap(source_map_size=self.params["global_map_size"], world_in_map_size=self.params["world_size_px"])

        # Pre-process the accumulated map to do language grounding if necessary - in the world reference frame
        if self.use_aux_grounding_on_map and not self.use_aux_grounding_features:
            self.map_processor_a_w = LangFilterMapProcessor(
                source_map_size=self.params["global_map_size"],
                world_size=self.params["world_size_px"],
                embed_size=self.params["emb_size"],
                in_channels=self.params["feature_channels"],
                out_channels=self.params["relevance_channels"],
                spatial=False, cat_out=True)
        else:
            self.map_processor_a_w = IdentityMapProcessor(source_map_size=self.params["global_map_size"], world_size=self.params["world_size_px"])

        if self.use_aux_goal_on_map:
            self.map_processor_b_r = LangFilterMapProcessor(source_map_size=self.params["local_map_size"],
                                                            world_size=self.params["world_size_px"],
                                                            embed_size=self.params["emb_size"],
                                                            in_channels=self.params["relevance_channels"],
                                                            out_channels=self.params["goal_channels"],
                                                            spatial=True, cat_out=True)
        else:
            self.map_processor_b_r = IdentityMapProcessor(source_map_size=self.params["local_map_size"],
                                                          world_size=self.params["world_size_px"])

        pred_channels = self.params["goal_channels"] + self.params["relevance_channels"]

        # Common
        # --------------------------------------------------------------------------------------------------------------

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"], self.params["emb_size"], self.params["emb_layers"])

        self.map_transform_w_to_r = MapTransformerBase(source_map_size=self.params["global_map_size"],
                                                       dest_map_size=self.params["local_map_size"],
                                                       world_size=self.params["world_size_px"])
        self.map_transform_r_to_w = MapTransformerBase(source_map_size=self.params["local_map_size"],
                                                       dest_map_size=self.params["global_map_size"],
                                                       world_size=self.params["world_size_px"])

        # Batch select is used to drop and forget semantic maps at those timesteps that we're not planning in
        self.batch_select = MapBatchSelect()
        # Since we only have path predictions for some timesteps (the ones not dropped above), we use this to fill
        # in the missing pieces by reorienting the past trajectory prediction into the frame of the current timestep
        self.map_batch_fill_missing = MapBatchFillMissing(self.params["local_map_size"], self.params["world_size_px"])

        # Passing true to freeze will freeze these weights regardless of whether they've been explicitly reloaded or not
        enable_weight_saving(self.sentence_embedding, "sentence_embedding", alwaysfreeze=False)

        # Output an action given the global semantic map
        if self.params["map_to_action"] == "downsample2":
            self.map_to_action = EgoMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"])

        elif self.params["map_to_action"] == "cropped":
            self.map_to_action = CroppedMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"]
            )

        # Don't freeze the trajectory to action weights, because it will be pre-trained during path-prediction training
        # and finetuned on all timesteps end-to-end
        enable_weight_saving(self.map_to_action, "map_to_action", alwaysfreeze=False, neverfreeze=True)

        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        # We add all auxiliaries that are necessary. The first argument is the auxiliary name, followed by parameters,
        # followed by a variable number of input names. ModuleWithAuxiliaries will automatically collect these inputs,
        # which are saved with keep_auxiliary_input() during execution (see the sketch after this example).
        if aux_class_features:
            self.add_auxiliary(ClassAuxiliary2D("aux_class", None,  self.params["feature_channels"], self.params["num_landmarks"], self.params["dropout"],
                                                "fpv_features", "lm_pos_fpv", "lm_indices"))
        if aux_grounding_features:
            self.add_auxiliary(ClassAuxiliary2D("aux_ground", None, self.params["relevance_channels"], 2, self.params["dropout"],
                                                "fpv_features_g", "lm_pos_fpv", "lm_mentioned"))
        if aux_class_map:
            self.add_auxiliary(ClassAuxiliary2D("aux_class_map", self.params["world_size_px"], self.params["feature_channels"], self.params["num_landmarks"], self.params["dropout"],
                                                "map_s_w_select", "lm_pos_map_select", "lm_indices_select"))
        if aux_grounding_map:
            self.add_auxiliary(ClassAuxiliary2D("aux_grounding_map", self.params["world_size_px"], self.params["relevance_channels"], 2, self.params["dropout"],
                                                "map_a_w_select", "lm_pos_map_select", "lm_mentioned_select"))
        if aux_goal_map:
            self.add_auxiliary(GoalAuxiliary2D("aux_goal_map", self.params["goal_channels"], self.params["world_size_px"],
                                               "map_b_w", "goal_pos_map"))
        # RSS model uses templated data for landmark and side prediction
        if self.use_aux_lang and self.params["templates"]:
            self.add_auxiliary(ClassAuxiliary("aux_lang_lm", self.params["emb_size"], self.params["num_landmarks"], 1,
                                                "sentence_embed", "lm_mentioned_tplt"))
            self.add_auxiliary(ClassAuxiliary("aux_lang_side", self.params["emb_size"], self.params["num_sides"], 1,
                                                "sentence_embed", "side_mentioned_tplt"))
        # CoRL model uses alignment-model groundings
        elif self.use_aux_lang:
            # One output for each landmark, 2 classes per output. This is for finetuning, so use the embedding that will be fine-tuned
            self.add_auxiliary(ClassAuxiliary("aux_lang_lm_nl", self.params["emb_size"], 2, self.params["num_landmarks"],
                                                "sentence_embed", "lang_lm_mentioned"))
        if self.use_aux_traj_on_map:
            self.add_auxiliary(PathAuxiliary2D("aux_path", "map_b_r_select", "traj_gt_r_select"))

        if self.use_aux_reg_map:
            self.add_auxiliary(FeatureRegularizationAuxiliary2D("aux_regularize_features", None, "l1",
                                                                "map_s_w_select", "lm_pos_map_select"))

        self.goal_good_criterion = GoalPredictionGoodCriterion(ok_distance=3.2)
        self.goal_acc_meter = MovingAverageMeter(10)

        self.print_auxiliary_info()

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0
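
The comment block above the auxiliary section describes the protocol: add_auxiliary() registers an objective by name, and tensors saved with keep_auxiliary_input() during the forward pass are later routed to it. A minimal sketch of such a container (an illustrative stub with assumed attribute names, not the repo's ModuleWithAuxiliaries):

class AuxiliaryContainerSketch:
    def __init__(self):
        self.auxiliaries = {}  # objective name -> auxiliary module
        self.aux_inputs = {}   # input name -> tensor saved during the forward pass

    def add_auxiliary(self, aux):
        self.auxiliaries[aux.name] = aux

    def keep_auxiliary_input(self, name, value):
        # Called from forward(); auxiliaries later pull these values by name
        self.aux_inputs[name] = value

    def compute_aux_losses(self):
        # Each auxiliary reads the named inputs it was registered with
        return {name: aux.loss(*[self.aux_inputs[k] for k in aux.input_names])
                for name, aux in self.auxiliaries.items()}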
Example #3
def train_rl_worker(sup_process_conn):
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    setup["trajectory_length"] = setup["rl_trajectory_length"]
    run_name = setup["run_name"]
    rlsup = P.get_current_parameters()["RLSUP"]
    params = P.get_current_parameters()["RL"]
    num_rl_epochs = params["num_epochs"]
    # These need to be distinguished between supervised and RL because supervised trains on ALL envs, RL only on 6000-7000
    setup["env_range_start"] = setup["rl_env_range_start"]
    setup["env_range_end"] = setup["rl_env_range_end"]
    rl_device = rlsup.get("rl_device", "cuda:0")

    trainer = TrainerRL(params=dict_merge(setup, params),
                        save_rollouts_to_dataset=rl_dataset_name(run_name),
                        device=rl_device)

    # -------------------------------------------------------------------------------------
    # TODO: Continue (including figuring out how to initialize the Supervised Stage 1 real/sim/critic and the RL Stage 2 policy)
    start_rl_epoch = 0
    for start_rl_epoch in range(num_rl_epochs):
        epfname = epoch_rl_filename(run_name, start_rl_epoch, model="full")
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_rl_epoch > 0:
        print(f"RLP: CONTINUING RL TRAINING FROM EPOCH: {start_rl_epoch}")
        load_pytorch_model(
            trainer.full_model,
            epoch_rl_filename(run_name, start_rl_epoch - 1, model="full"))
        trainer.set_start_epoch(start_rl_epoch)
    # Wait for the supervised process to send its model
    sleep(2)

    # -------------------------------------------------------------------------------------

    print("RLP: Beginning training...")
    for rl_epoch in range(start_rl_epoch, num_rl_epochs):
        # Get the latest Stage 1 model. Halt on the first epoch so that we can actually initialize the Stage 1
        new_stage1_model_state_dict = receive_stage1_state(
            sup_process_conn, halt=(rl_epoch == start_rl_epoch))
        if new_stage1_model_state_dict:
            print(f"RLP: Re-loading latest Stage 1 model")
            trainer.reload_stage1(new_stage1_model_state_dict)

        train_reward, metrics = trainer.train_epoch(epoch_num=rl_epoch,
                                                    eval=False,
                                                    envs="train")
        dev_reward, metrics = trainer.train_epoch(epoch_num=rl_epoch,
                                                  eval=True,
                                                  envs="dev")

        print("RLP: RL Epoch", rl_epoch, "train reward:", train_reward,
              "dev reward:", dev_reward)
        save_pytorch_model(trainer.full_model,
                           epoch_rl_filename(run_name, rl_epoch, model="full"))
        save_pytorch_model(
            trainer.full_model.stage1_visitation_prediction,
            epoch_rl_filename(run_name, rl_epoch, model="stage1"))
        save_pytorch_model(
            trainer.full_model.stage2_action_generation,
            epoch_rl_filename(run_name, rl_epoch, model="stage2"))
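
The checkpoint-scan loop above resumes training by finding the first epoch whose saved model is missing. The same idiom, factored into a helper (find_resume_epoch is a sketch, not a function from the repo):

import os

def find_resume_epoch(make_path, num_epochs):
    """Return the first epoch whose checkpoint file does not exist yet."""
    for epoch in range(num_epochs):
        if not os.path.exists(make_path(epoch)):
            return epoch
    return num_epochs  # every checkpoint exists; training already completed

Here make_path would wrap the epoch_rl_filename()/get_model_dir() composition used above.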
Example #4
File: paths.py  Project: dxsun/drif
def get_config_base_dir():
    base_dir = get_current_parameters()["Environment"]["config_dir"]
    return base_dir
Example #5
def train_dagger():
    P.initialize_experiment()
    global PARAMS
    PARAMS = P.get_current_parameters()["Dagger"]
    setup = P.get_current_parameters()["Setup"]
    dataset_name = P.get_current_parameters()["Data"]["dataset_name"]

    if setup["num_workers"] > 1:
        roller = ParallelPolicyRoller(num_workers=setup["num_workers"], first_worker=setup["first_worker"], reduce=PARAMS["segment_level"])
    else:
        roller = PolicyRoller()

    latest_model_filename = "dagger_" + setup["model"] + "_" + setup["run_name"]
    dagger_data_dir = "dagger_data/" + setup["run_name"] + "/"

    save_json(PARAMS, dagger_data_dir + "run_params.json")

    # Load less tf data, but sample dagger rollouts from more environments to avoid overfitting.
    train_envs, dev_envs, test_envs = data_io.instructions.get_restricted_env_id_lists(max_envs=PARAMS["max_envs_dag"])

    if PARAMS["resample_supervised_data"]:
        # Supervised data are represented as integers that will be later loaded by the dataset
        all_train_data = list(range(PARAMS["max_samples_in_memory"]))
        all_test_data = []
    else:
        all_train_data, all_test_data = data_io.train_data.load_supervised_data(dataset_name, max_envs=PARAMS["max_envs_sup"], split_segments=PARAMS["segment_level"])

    resample_supervised_data(dataset_name, all_train_data, train_envs)
    resample_supervised_data(dataset_name, all_test_data, test_envs)

    print("Loaded tf data size: " + str(len(all_train_data)) + " : " + str(len(all_test_data)))

    model = load_dagger_model(latest_model_filename)
    data_io.model_io.save_pytorch_model(model, latest_model_filename)

    if PARAMS["restore_latest"]:
        all_train_data, all_test_data = restore_data_latest(dagger_data_dir, dataset_name)
    else:
        restore_data(dataset_name, dagger_data_dir, all_train_data, all_test_data)

    last_trainer_state = None

    for iteration in range(PARAMS["restore"], PARAMS["max_iterations"]):
        gc.collect()
        print("-------------------------------")
        print("DAGGER ITERATION : ", iteration)
        print("-------------------------------")

        test_data_i = all_test_data

        # If we have too many training examples in memory, discard uniformly at random to keep a somewhat fixed bound
        max_samples = PARAMS["max_samples_in_memory"]
        if max_samples > 0 and len(all_train_data) > max_samples:  # and iteration != args.dagger_restore:
            num_discard = len(all_train_data) - max_samples
            print("Too many samples in memory! Dropping " + str(num_discard) + " samples")
            discards = set(random.sample(list(range(len(all_train_data))), num_discard))
            all_train_data = [sample for i, sample in enumerate(all_train_data) if i not in discards]
            print("Now left " + str(len(all_train_data)) + " samples")

        # Roll out new data at iteration i, except if we are restoring to that iteration, in which case we already have data
        if iteration != PARAMS["restore"] or iteration == 0:
            train_data_i, test_data_i = collect_iteration_data(roller, iteration, train_envs, test_envs, latest_model_filename, dagger_data_dir, dataset_name)

            # Aggregate the dataset
            all_train_data += train_data_i
            all_test_data += test_data_i
            print("Aggregated dataset!)")
            print("Total samples: ", len(all_train_data))
            print("New samples: ", len(train_data_i))

        data_io.train_data.save_dataset(dataset_name, all_train_data, dagger_data_dir + "train_latest")
        data_io.train_data.save_dataset(dataset_name, test_data_i, dagger_data_dir + "test_latest")

        model, model_loaded = load_latest_model(latest_model_filename)

        trainer = Trainer(model, state=last_trainer_state)

        import rollout.run_metadata as run_md
        run_md.IS_ROLLOUT = False

        # Train on the newly aggregated dataset
        num_epochs = PARAMS["epochs_per_iteration_override"][iteration] if iteration in PARAMS["epochs_per_iteration_override"] else PARAMS["epochs_per_iteration"]
        for epoch in range(num_epochs):

            # Get a random sample of all test data for calculating eval loss
            #epoch_test_sample = sample_n_from_list(all_test_data, PARAMS["num_test_samples"])
            # Just evaluate on the latest test data
            epoch_test_sample = test_data_i

            loss = trainer.train_epoch(all_train_data)
            test_loss = trainer.train_epoch(epoch_test_sample, eval=True)

            data_io.model_io.save_pytorch_model(trainer.model, latest_model_filename)
            print("Epoch", epoch, "Loss: Train:", loss, "Test:", test_loss)

        data_io.model_io.save_pytorch_model(trainer.model, get_model_filename_at_iteration(setup, iteration))
        if hasattr(trainer.model, "save"):
            trainer.model.save("dag" + str(iteration))
        last_trainer_state = trainer.get_state()
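
The num_epochs line above checks membership and then indexes the same dictionary; dict.get collapses the two steps into one call (a drop-in equivalent):

num_epochs = PARAMS["epochs_per_iteration_override"].get(iteration, PARAMS["epochs_per_iteration"])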
Example #6
def get_all_instructions(max_size=0, do_prune_ambiguous=False, full=False):
    #print("max_size:", max_size)

    # If instructions already loaded in memory, return them
    global cache
    global loaded_corpus
    global loaded_size

    if full:
        min_augment_len = 1
    else:
        min_augment_len = P.get_current_parameters()["Setup"].get(
            "min_augment_len", 1)
    max_augment_len = P.get_current_parameters()["Setup"].get("augment_len", 1)

    train_key = f"train-{min_augment_len}-{max_augment_len}"
    dev_key = f"dev-{min_augment_len}-{max_augment_len}"
    test_key = f"test-{min_augment_len}-{max_augment_len}"

    if cache is not None and train_key in cache:  # loaded_size == max_size:
        train_instructions = cache[train_key]
        dev_instructions = cache[dev_key]
        test_instructions = cache[test_key]
        corpus = loaded_corpus

    # Otherwise, see if they've been pre-built in tmp files
    else:
        # Cache
        cache_dir = get_instruction_cache_dir()
        corpus_dir = get_config_dir()

        train_file = os.path.join(
            cache_dir, f"train_{min_augment_len}-{max_augment_len}.json")
        dev_file = os.path.join(
            cache_dir, f"dev_{min_augment_len}-{max_augment_len}.json")
        test_file = os.path.join(
            cache_dir, f"test_{min_augment_len}-{max_augment_len}.json")
        corpus_file = os.path.join(corpus_dir, "corpus.json")
        wfreq_file = os.path.join(corpus_dir, "word_freq.json")

        corpus_already_exists = False
        if os.path.isfile(corpus_file):
            with open(corpus_file, "r") as f:
                corpus = list(json.load(f))
                #print("corpus: ", len(corpus))
            corpus_already_exists = True

        # If they have been saved in tmp files, load them
        if os.path.isfile(train_file):
            train_instructions = load_instruction_data_from_json(train_file)
            dev_instructions = load_instruction_data_from_json(dev_file)
            test_instructions = load_instruction_data_from_json(test_file)
            assert corpus_already_exists, "Instruction data exists but corpus is gone!"

        # Otherwise rebuild instruction data from annotations
        else:
            print(
                f"REBUILDING INSTRUCTION DATA FOR SEGMENT LENGTHS: {min_augment_len} to {max_augment_len}!"
            )
            print(f"USING OLD CORPUS: {corpus_already_exists}")
            os.makedirs(cache_dir, exist_ok=True)

            all_instructions, new_corpus = defaultdict(list), set()

            train_an, dev_an, test_an = load_train_dev_test_annotations()

            print("Loaded JSON Data")

            print("Parsing dataset")
            print("    train...")
            train_instructions, new_corpus, word_freq = parse_dataset(
                train_an, new_corpus)
            print("    dev...")
            dev_instructions, new_corpus, _ = parse_dataset(dev_an, new_corpus)
            print("    test...")
            test_instructions, new_corpus, _ = parse_dataset(
                test_an, new_corpus)

            print("Augmenting maybe?")
            train_instructions = augment_dataset(train_instructions,
                                                 merge_len=max_augment_len,
                                                 min_merge_len=min_augment_len)
            dev_instructions = augment_dataset(dev_instructions,
                                               merge_len=max_augment_len,
                                               min_merge_len=min_augment_len)
            test_instructions = augment_dataset(test_instructions,
                                                merge_len=max_augment_len,
                                                min_merge_len=min_augment_len)

            save_json(train_instructions, train_file)
            save_json(dev_instructions, dev_file)
            save_json(test_instructions, test_file)

            if not corpus_already_exists:
                corpus = new_corpus
                save_json(list(corpus), corpus_file)
                save_json(word_freq, wfreq_file)
            else:
                print("Warning! Regenerated pomdp, but kept the old corpus!")

            print("Saved instructions for quicker loading!")

    # Clip datasets to the provided size
    if max_size is not None and max_size > 0:
        num_train = int(math.ceil(max_size * 0.7))
        num_dev = int(math.ceil(max_size * 0.15))
        num_test = int(math.ceil(max_size * 0.15))

        train_instructions = slice_list_tail(train_instructions, num_train)
        dev_instructions = slice_list_tail(dev_instructions, num_dev)
        test_instructions = slice_list_tail(test_instructions, num_test)

    if do_prune_ambiguous:
        train_instructions = prune_ambiguous(train_instructions)
        dev_instructions = prune_ambiguous(dev_instructions)
        test_instructions = prune_ambiguous(test_instructions)

    #print("Corpus: ", len(corpus))
    #print("Loaded: ", len(train_instructions), len(dev_instructions), len(test_instructions))
    if cache is None:
        cache = {}

    cache[train_key] = train_instructions
    cache[dev_key] = dev_instructions
    cache[test_key] = test_instructions
    loaded_corpus = corpus
    loaded_size = max_size

    return train_instructions, dev_instructions, test_instructions, corpus
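
One detail of the size-clipping arithmetic above: math.ceil is applied to each split independently, so the three slices can sum to slightly more than max_size. A quick check:

import math

max_size = 10
splits = [math.ceil(max_size * f) for f in (0.7, 0.15, 0.15)]
print(splits, sum(splits))  # [7, 2, 2] 11 -> one more than max_size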
Example #7
def train_top_down_pred():
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    launch_ui()

    env = PomdpInterface()

    model, model_loaded = load_model(model_name_override=setup["top_down_model"],
                                     model_file_override=setup["top_down_model_file"])

    exec_model, wrapper_model_loaded = load_model(model_name_override=setup["wrapper_model"],
                                                  model_file_override=setup["wrapper_model_file"])

    affine2d = Affine2D()
    if model.is_cuda:
        affine2d.cuda()

    eval_envs = get_correct_eval_env_id_list()
    train_instructions, dev_instructions, test_instructions, corpus = get_all_instructions(max_size=setup["max_envs"])
    all_instr = {**train_instructions, **dev_instructions, **test_instructions}
    token2term, word2token = get_word_to_token_map(corpus)

    dataset = model.get_dataset(envs=eval_envs, dataset_name="supervised", eval=True, seg_level=False)
    dataloader = DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True)

    for b, batch in enumerate(dataloader):

        images = batch["images"]
        instructions = batch["instr"]
        label_masks = batch["traj_labels"]
        affines = batch["affines_g_to_s"]
        env_ids = batch["env_id"]
        set_idxs = batch["set_idx"]
        seg_idxs = batch["seg_idx"]

        env_id = env_ids[0][0]
        set_idx = set_idxs[0][0]
        env.set_environment(env_id, instruction_set=all_instr[env_id][set_idx]["instructions"])
        env.reset(0)

        num_segments = len(instructions[0])

        write_instruction("")
        write_real_instruction("None")
        instruction_str = read_instruction_file()
        print("Initial instruction: ", instruction_str)

        # TODO: Reset model state here if we keep any temporal memory etc
        for s in range(num_segments):
            start_state = env.reset(s)
            keep_going = True
            real_instruction = cuda_var(instructions[0][s], setup["cuda"], 0)
            tmp = list(real_instruction.data.cpu()[0].numpy())
            real_instruction_str = debug_untokenize_instruction(tmp)
            write_real_instruction(real_instruction_str)
            #write_instruction(real_instruction_str)
            #instruction_str = real_instruction_str
            image = cuda_var(images[0][s], setup["cuda"], 0)
            label_mask = cuda_var(label_masks[0][s], setup["cuda"], 0)
            affine_g_to_s = affines[0][s]

            while keep_going:
                write_real_instruction(real_instruction_str)

                while True:
                    cv2.waitKey(200)
                    instruction = read_instruction_file()
                    if instruction == "CMD: Next":
                        print("Advancing")
                        keep_going = False
                        write_empty_instruction()
                        break
                    elif instruction == "CMD: Reset":
                        print("Resetting")
                        env.reset(s)
                        write_empty_instruction()
                    elif len(instruction.split(" ")) > 1:
                        instruction_str = instruction
                        print("Executing: ", instruction_str)
                        break

                if not keep_going:
                    continue

                #instruction_str = read_instruction_file()
                # TODO: Load instruction from file
                tok_instruction = tokenize_instruction(instruction_str, word2token)
                instruction_t = torch.LongTensor(tok_instruction).unsqueeze(0)
                instruction_v = cuda_var(instruction_t, setup["cuda"], 0)
                instruction_mask = torch.ones_like(instruction_v)
                tmp = list(instruction_t[0].numpy())
                instruction_dbg_str = debug_untokenize_instruction(tmp, token2term)

                res = model(image, instruction_v, instruction_mask)
                mask_pred = res[0]
                shp = mask_pred.shape
                mask_pred = F.softmax(mask_pred.view([2, -1]), 1).view(shp)
                #mask_pred = softmax2d(mask_pred)

                # TODO: Rotate the mask_pred to the global frame
                affine_s_to_g = np.linalg.inv(affine_g_to_s)
                S = 8.0
                affine_scale_up = np.asarray([[S, 0, 0],
                                              [0, S, 0],
                                              [0, 0, 1]])
                affine_scale_down = np.linalg.inv(affine_scale_up)

                affine_pred_to_g = np.dot(affine_scale_down, np.dot(affine_s_to_g, affine_scale_up))
                #affine_pred_to_g_t = torch.from_numpy(affine_pred_to_g).float()

                mask_pred_np = mask_pred.data.cpu().numpy()[0].transpose(1, 2, 0)
                mask_pred_g_np = apply_affine(mask_pred_np, affine_pred_to_g, 32, 32)
                print("Sum of global mask: ", mask_pred_g_np.sum())
                mask_pred_g = torch.from_numpy(mask_pred_g_np.transpose(2, 0, 1)).float()[np.newaxis, :, :, :]
                exec_model.set_ground_truth_visitation_d(mask_pred_g)

                # Create a batch axis for pytorch
                #mask_pred_g = affine2d(mask_pred, affine_pred_to_g_t[np.newaxis, :, :])

                mask_pred_np[:, :, 0] -= mask_pred_np[:, :, 0].min()
                mask_pred_np[:, :, 0] /= (mask_pred_np[:, :, 0].max() + 1e-9)
                mask_pred_np[:, :, 0] *= 2.0
                mask_pred_np[:, :, 1] -= mask_pred_np[:, :, 1].min()
                mask_pred_np[:, :, 1] /= (mask_pred_np[:, :, 1].max() + 1e-9)

                presenter = Presenter()
                #presenter.show_image(mask_pred_g_np, "mask_pred_g", torch=False, waitkey=1, scale=4)
                pred_viz_np = presenter.overlaid_image(image.data, mask_pred_np, channel=0)
                # TODO: Don't show labels
                # TODO: OpenCV colours
                #label_mask_np = p.data.cpu().numpy()[0].transpose(1,2,0)

                labl_viz_np = presenter.overlaid_image(image.data, label_mask.data, channel=0)
                viz_img_np = np.concatenate((pred_viz_np, labl_viz_np), axis=1)
                viz_img_np = pred_viz_np

                viz_img = presenter.overlay_text(viz_img_np, instruction_dbg_str)
                cv2.imshow("interactive viz", viz_img)
                cv2.waitKey(100)

                rollout_model(exec_model, env, env_ids[0][s], set_idxs[0][s], seg_idxs[0][s], tok_instruction)
                write_instruction("")
Example #8
def get_sim_config_dir():
    return get_current_parameters()["Environment"]["sim_config_dir"]
Example #9
def train_dagger_simple():
    # ----------------------------------------------------------------------------------------------------------------
    # Load params and configure stuff

    P.initialize_experiment()
    params = P.get_current_parameters()["SimpleDagger"]
    setup = P.get_current_parameters()["Setup"]
    num_iterations = params["num_iterations"]
    sim_seed_dataset = params.get("sim_seed_dataset")
    run_name = setup["run_name"]
    device = params.get("device", "cuda:1")
    dataset_limit = params.get("dataset_size_limit_envs")
    seed_count = params.get("seed_count")

    # Trigger rebuild if necessary before going into all the threads and processes
    _ = get_restricted_env_id_lists(full=True)

    # Initialize the dataset
    if sim_seed_dataset:
        copy_seed_dataset(from_dataset=sim_seed_dataset,
                          to_dataset=dagger_dataset_name(run_name),
                          seed_count=seed_count or dataset_limit)
        gap = 0
    else:
        # TODO: Refactor this into a prompt function
        data_path = get_dataset_dir(dagger_dataset_name(run_name))
        if os.path.exists(data_path):
            print("DATASET EXISTS! Continue where left off?")
            c = input(" (y/n) >>> ")
            if c != "y":
                raise ValueError(
                    f"Not continuing: Dataset {data_path} exists. Delete it if you like and try again"
                )
        else:
            os.makedirs(data_path, exist_ok=True)
        gap = dataset_limit - len(os.listdir(data_path))

    print("SUPP: Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    # ----------------------------------------------------------------------------------------------------------------
    # Load / initialize model

    model = load_model(setup["model"], setup["model_file"],
                       domain="sim")[0].to(device)
    oracle = load_model("oracle")[0]

    # ----------------------------------------------------------------------------------------------------------------
    # Continue where we left off - load the model and set the iteration/epoch number

    for start_iteration in range(10000):
        epfname = epoch_dag_filename(run_name, start_iteration)
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_iteration > 0:
        print(
            f"DAG: CONTINUING DAGGER TRAINING FROM ITERATION: {start_iteration}"
        )
        load_pytorch_model(model,
                           epoch_dag_filename(run_name, start_iteration - 1))

    # ----------------------------------------------------------------------------------------------------------------
    # Initialize trainer

    trainer = Trainer(model,
                      epoch=start_iteration,
                      name=setup["model"],
                      run_name=setup["run_name"])
    trainer.set_dataset_names([dagger_dataset_name(run_name)])

    # ----------------------------------------------------------------------------------------------------------------
    # Initialize policy roller

    roller = SimpleParallelPolicyRoller(
        num_workers=params["num_workers"],
        device=params["device"],
        policy_name=setup["model"],
        policy_file=setup["model_file"],
        oracle=oracle,
        dataset_save_name=dagger_dataset_name(run_name),
        no_reward=True)
    rollout_sampler = RolloutSampler(roller)

    # ----------------------------------------------------------------------------------------------------------------
    # Train DAgger - loop over iterations; in each, prune, roll out, and train an epoch

    print("SUPP: Beginning training...")
    for iteration in range(start_iteration, num_iterations):
        print(f"DAG: Starting iteration {iteration}")

        # Remove extra rollouts to keep within DAggerFM limit
        prune_dataset(run_name, dataset_limit)

        # Rollout and collect more data for training and evaluation
        policy_state = model.get_policy_state()
        rollout_sampler.sample_n_rollouts(
            n=gap if iteration == 0 else params["train_envs_per_iteration"],
            policy_state=policy_state,
            sample=False,
            envs="train",
            dagger_beta=dagger_beta(params, iteration))

        eval_rollouts = rollout_sampler.sample_n_rollouts(
            n=params["eval_envs_per_iteration"],
            policy_state=policy_state,
            sample=False,
            envs="dev",
            dagger_beta=0)

        # Kill airsim instances so that they don't take up GPU memory and in general slow things down during training
        roller.kill_airsim()

        # Evaluate success / metrics and save to tensorboard
        if setup["eval_nl"]:
            evaler = DataEvalNL(run_name,
                                entire_trajectory=False,
                                save_images=False)
            evaler.evaluate_dataset(eval_rollouts)
            results = evaler.get_results()
            print("Results:", results)
            evaler.write_summaries(setup["run_name"], "dagger_eval", iteration)

        # Do one epoch of supervised training
        print("SUPP: Beginning Epoch")
        train_loss = trainer.train_epoch(train_envs=train_envs, eval=False)
        #test_loss = trainer.train_epoch(env_list_common=dev_envs_common, env_list_sim=dev_envs_sim, eval=True)

        # Save the model to file
        print("SUPP: Epoch", iteration, "train_loss:", train_loss)
        save_pytorch_model(model, epoch_dag_filename(run_name, iteration))
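
dagger_beta(params, iteration) is not shown in this example. A common DAgger schedule (an assumption for illustration, not necessarily the repo's implementation) decays the probability of following the oracle exponentially in the iteration number:

def dagger_beta_sketch(params, iteration, decay=0.9):
    # beta = 1.0: always execute the oracle action; beta = 0.0: always execute the learned policy
    return decay ** iteration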
Example #10
 def _read_clock_speed(self):
     speed = 1.0
     if "ClockSpeed" in P.get_current_parameters()["AirSim"]:
         speed = P.get_current_parameters()["AirSim"]["ClockSpeed"]
     print("Read clock speed: " + str(speed))
     return speed
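
The guarded lookup above can also be written with dict.get, which carries the same assumption the original makes, namely that the "AirSim" section exists (a behavior-equivalent sketch):

    def _read_clock_speed(self):
        speed = P.get_current_parameters()["AirSim"].get("ClockSpeed", 1.0)
        print("Read clock speed: " + str(speed))
        return speed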
Example #11
    def evaluate_rollout(self, rollout):
        last_sample = rollout[-1]
        if "metadata" not in last_sample:
            last_sample["metadata"] = last_sample
        env_id = last_sample["metadata"]["env_id"]
        # TEMPORARY FOR APPENDIX TABLE! REMOVE IT!
        # if env_id >= 6000:
        #    return None
        seg_idx = last_sample["metadata"]["seg_idx"]
        set_idx = last_sample["metadata"]["set_idx"]

        path = load_and_convert_path(env_id)

        seg_ordinal = seg_idx_to_ordinal(
            self.all_i[env_id][set_idx]["instructions"], seg_idx)
        instr_seg = self.all_i[env_id][set_idx]["instructions"][seg_ordinal]

        if self.entire_trajectory:
            path_end_idx = len(path) - 1
            path_start_idx = 0
        else:
            # Find the segment end index
            path_end_idx = self.all_i[env_id][set_idx]["instructions"][
                seg_ordinal]["end_idx"] + 1
            path_start_idx = self.all_i[env_id][set_idx]["instructions"][
                seg_ordinal]["start_idx"]
            if path_end_idx > len(path) - 1:
                path_end_idx = len(path) - 1
            if path_end_idx < path_start_idx:
                path_start_idx = path_end_idx

        seg_path = path[path_start_idx:path_end_idx]
        goal_visible = self.is_goal_visible(instr_seg)
        self.visible_map[f"{env_id}_{seg_idx}"] = (1 if goal_visible else 0)
        exec_path = np.asarray([r["state"].get_pos_2d() for r in rollout])

        end_pos = np.asarray(exec_path[-1])
        target_end_pos = np.asarray(seg_path[-1])
        end_dist = np.linalg.norm(end_pos - target_end_pos)
        success = end_dist < self.passing_distance

        # EMD between trajectories, and EMD between start position and trajectory.
        exec_path = self._filter_path(exec_path)
        gt_path = self._filter_path(seg_path)
        emd = self._calculate_emd(exec_path, gt_path)
        stop_emd = self._calculate_emd(exec_path[0:1], gt_path)

        # Success weighted by earth-mover's distance
        nemd = emd / stop_emd
        semd = max((1 if success else 0) * (1 - nemd), 0)

        if last_sample["metadata"]["pol_action"][3] > 0.5:
            who_stopped = "Policy Stopped"
        elif last_sample["metadata"]["ref_action"][3] > 0.5:
            who_stopped = "Oracle Stopped"
        else:
            who_stopped = "Veered Off"

        result = "Success" if success else "Fail"
        print(env_id, set_idx, seg_idx, result)

        texts = [who_stopped, result, "run:" + self.run_name]

        #print(seg_idx, result, semd)

        if self.save_images and emd:
            results_dir = get_results_dir(self.run_name, makedir=True)
            print("Results dir: ", results_dir)
            # TODO: Refactor this to not pull path from rollout, but provide it explicitly
            self.presenter.plot_paths(
                rollout,
                segment_path=gt_path,
                interactive=False,
                texts=texts,
                entire_trajectory=self.entire_trajectory,
                world_size=P.get_current_parameters()["Setup"]["world_size_m"],
                real_drone=P.get_current_parameters()["Setup"]["real_drone"])
            filename = os.path.join(
                results_dir,
                str(env_id) + "_" + str(set_idx) + "_" + str(seg_idx))
            if self.custom_instr is not None:
                filename += "_" + last_sample["metadata"][
                    "instruction"][:24] + "_" + last_sample["metadata"][
                        "instruction"][-16:]
            self.presenter.save_plot(filename)

        #if emd:
        #    self.save_results()

        return ResultsLandmarkSide(success=success,
                                   end_dist=end_dist,
                                   goal_visible=goal_visible,
                                   emd=emd,
                                   semd=semd,
                                   nemd=nemd)
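
A worked example of the SEMD computation above (illustrative numbers): stop_emd is the EMD of simply standing at the start position, so nemd normalizes the trajectory's EMD by that do-nothing baseline.

success = True
emd, stop_emd = 0.5, 2.0
nemd = emd / stop_emd                              # 0.25
semd = max((1 if success else 0) * (1 - nemd), 0)  # 0.75
print(nemd, semd)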
Example #12
 def rollout_begin(self, instruction):
     self.camcorder1.start_recording_rollout(P.get_current_parameters()["Setup"]["run_name"], self.env_id, 0, self.seg_idx, caption=instruction)
Example #13
    def __init__(self,
                 data=None,
                 env_list=None,
                 dataset_names=["simulator"],
                 dataset_prefix="supervised",
                 domain="sim",
                 max_traj_length=None,
                 aux_provider_names=[],
                 segment_level=False,
                 cache=False):
        """
        Dataset for the replay memory
        :param data: if data is pre-loaded in memory, this is the training data
        :param env_list: if data is to be loaded by the dataset, this is the list of environments for which to include data
        :param dataset_names: list of datasets from which to load data
        :param dataset_prefix: name of the dataset. Default: supervised will use data collected with collect_supervised_data
        :param max_traj_length: truncate trajectories to this long
        :param cuda:
        :param aux_provider_names:
        """

        # If data is already loaded in memory, use it
        self.data = data
        self.prof = SimpleProfiler(torch_sync=False, print=PROFILE)
        self.min_seg_len = P.get_current_parameters()["Data"].get("min_seg_len", 3)
        self.do_cache = P.get_current_parameters()["Data"].get("cache", False)
        self.dataset_prefix = dataset_prefix
        self.dataset_names = dataset_names
        self.domain = domain

        self.env_restrictions = P.get_current_parameters()["Data"].get("dataset_env_restrictions")
        if self.env_restrictions:
            self.dataset_restricted_envs = {dname:P.get_current_parameters()["Data"]["EnvRestrictionGroups"][self.env_restrictions[dname]] for dname in dataset_names if dname in self.env_restrictions}
            print(f"Using restricted envs: {list(self.dataset_restricted_envs.keys())}")
        else:
            self.dataset_restricted_envs = {}

        self.max_traj_length = max_traj_length
        train_instr, dev_instr, test_instr, corpus = get_all_instructions()
        # TODO: This shouldn't have access to all instructions. We should really make distinct train, dev, test modes
        self.all_instr = {**train_instr, **dev_instr, **test_instr}

        train_instr_full, dev_instr_full, test_instr_full, corpus = get_all_instructions(full=True)
        self.all_instr_full = {**train_instr_full, **dev_instr_full, **test_instr_full}

        self.segment_level = segment_level
        self.sample_ids = []

        if self.data is None:
            assert env_list is not None
            for i, dataset_name in enumerate(self.dataset_names):
                dataset_env_list = filter_env_list_has_data(dataset_name, env_list, dataset_prefix)
                if self.segment_level:
                    dataset_env_list, dataset_seg_list = self.split_into_segments(dataset_env_list, dataset_name)
                else:
                    dataset_seg_list = [0 for _ in dataset_env_list]
                for env, seg in zip(dataset_env_list, dataset_seg_list):
                    self.sample_ids.append((dataset_name, env, seg))

        self.token2word, self.word2token = get_word_to_token_map(corpus)
        self.aux_provider_names = aux_provider_names
        self.aux_label_names = get_aux_label_names(aux_provider_names)
        self.stackable_names = get_stackable_label_names(aux_provider_names)
        self.data_cache = {dataset_name:{} for dataset_name in dataset_names}

        self.traj_len = P.get_current_parameters()["Setup"]["trajectory_length"]
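
After the indexing loop above, self.sample_ids is a flat list of (dataset_name, env, seg) triples, one per retrievable sample; illustrative contents (values invented):

# [("simulator", 6000, 0), ("simulator", 6000, 1), ("simulator", 6001, 0), ...]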
Example #14
 def _write_airsim_settings(self):
     airsim_settings = P.get_current_parameters()["AirSim"]
     airsim_settings_path = P.get_current_parameters()["Environment"]["airsim_settings_path"]
     airsim_settings_path = os.path.expanduser(airsim_settings_path)
     save_json(airsim_settings, airsim_settings_path)
     print("Wrote new AirSim settings to " + str(airsim_settings_path))
Example #15
def generate_rollout_debug_visualizations():
    setup = P.get_current_parameters()["Setup"]

    dataset_name = setup.get("viz_dataset_name") or get_eval_tmp_dataset_name(setup["model"], setup["run_name"])
    domain = setup.get("viz_domain") or ("real" if setup.get("real_drone") else "sim")
    run_name = setup.get("original_run_name") or setup.get("run_name")
    specific_envs = setup.get("only_specific_envs")

    # For collecting information for visualization examples
    specific_segments = [
        # running example
        (6827, 0, 4),
        # successful examples
        (6169, 0, 9),
        (6825, 0, 8),
        (6857, 0, 9),
        # failure examples
        (6169, 0, 2),
        (6299, 0, 9),
        (6634, 0, 8),
        (6856, 0, 9),
        (6857, 0, 8),
    ]
    specific_segments += [
        # good sim, lousy real
        (6419, 0, 5),
        (6569, 0, 6),
        (6634, 0, 6),
        (6917, 0, 7),
    ]
    specific_envs = [s[0] for s in specific_segments]

    # Generate all
    #specific_envs = list(range(6000, 7000, 1))
    #specific_segments = None

    # Some quick params. TODO: Bring this into json
    viz_params = {
        "ego_vdist": False,
        "draw_landmarks": False,
        "draw_topdown": True,
        "draw_drone": True,
        "draw_trajectory": True,
        "draw_fov": False,
        "include_vdist": False,
        "include_layer": None,
        "include_instr": False
    }

    print("Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    # TODO: Grab the correct env list
    env_list = dev_envs

    viz = RolloutVisualizer(resolution=576)
    base_dir = os.path.join(get_rollout_debug_viz_dir(), f"{dataset_name}-{domain}")
    os.makedirs(base_dir, exist_ok=True)

    for env_id in env_list:
        if specific_envs and env_id not in specific_envs:
            print("Skipping", env_id)
            continue
        try:
            env_data = load_single_env_from_dataset(dataset_name, env_id, "supervised")
        except FileNotFoundError as e:
            print(f"Skipping env: {env_id}")
            continue
        if len(env_data) == 0:
            print(f"Skipping env: {env_id}. Rollout exists but is EMPTY!")
            continue
        segs = split_into_segs(env_data)
        for seg in segs:
            lag_start = 1.5
            end_lag = 1.5
            seg_idx = seg[0]["seg_idx"]
            if specific_segments and (env_id, 0, seg_idx) not in specific_segments:
                continue
            seg_name = f"{env_id}:0:{seg_idx}-{domain}"
            gif_filename = f"{seg_name}-roll"
            instr_filename = f"{seg_name}-instr.txt"
            this_dir = os.path.join(base_dir, gif_filename)
            os.makedirs(this_dir, exist_ok=True)
            base_path = os.path.join(this_dir, gif_filename)
            if os.path.exists(os.path.join(this_dir, instr_filename)):
                continue

            # Animation with just the drone
            frames = viz.top_down_visualization(env_id, seg_idx, seg, domain, viz_params)
            save_frames(viz, frames, f"{base_path}-exec", fps=5.0, start_lag=lag_start, end_lag=end_lag, formats=Y_FMT)

            # Save instruction to file
            with open(os.path.join(this_dir, instr_filename), "w") as fp:
                fp.write(seg[0]["instruction"])

            # Animation of action
            frames = viz.action_visualization(env_id, seg_idx, seg, domain, "action")
            save_frames(viz, frames, f"{base_path}-action", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            # Animation of actions
            #action_frames = viz.grab_frames(env_id, seg_idx, seg, domain, "action", scale=4)
            #save_frames(viz, action_frames, f"{base_path}-action", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            # Generate and save gif
            # Bare top-down view
            mod_params = deepcopy(viz_params)
            mod_params["draw_drone"] = False
            mod_params["draw_trajectory"] = False
            frames = viz.top_down_visualization(env_id, seg_idx, seg, domain, mod_params)
            save_frames(viz, frames, f"{base_path}-top-down", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            mod_params["draw_drone"] = True
            mod_params["draw_trajectory"] = False
            frames = viz.top_down_visualization(env_id, seg_idx, seg, domain, mod_params)
            save_frames(viz, frames, f"{base_path}-top-down-drn", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            # Egocentric visitation distributions
            vdist_r_frames = viz.grab_frames(env_id, seg_idx, seg, domain, "v_dist_r_inner")
            save_frames(viz, vdist_r_frames, f"{base_path}-ego-vdist", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            # Map struct
            map_struct_frames = viz.grab_frames(env_id, seg_idx, seg, domain, "map_struct")
            save_frames(viz, map_struct_frames, f"{base_path}-ego-map-struct", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            # Egocentric observation mask
            ego_obs_mask_frames = viz.grab_frames(env_id, seg_idx, seg, domain, "ego_obs_mask")
            save_frames(viz, ego_obs_mask_frames, f"{base_path}-ego-obs-mask", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            def save_map_permutations(file_prefix, incl_layer):
                mod_params = deepcopy(viz_params)
                if incl_layer == "vdist":
                    mod_params["include_vdist"] = True
                else:
                    mod_params["include_layer"] = incl_layer
                print(f"GENERATING: {file_prefix}")
                # Non-overlaid, without trajectory
                mod_params["draw_drone"] = False
                mod_params["draw_topdown"] = False
                mod_params["draw_trajectory"] = False
                frames = viz.top_down_visualization(env_id, seg_idx, seg, domain, mod_params)
                save_frames(viz, frames, f"{file_prefix}", fps=5.0, start_lag=lag_start, end_lag=end_lag, formats=Y_FMT)

                print(f"GENERATING: {file_prefix}-ov")
                # Overlaid, without trajectory
                mod_params["draw_topdown"] = True
                frames = viz.top_down_visualization(env_id, seg_idx, seg, domain, mod_params)
                save_frames(viz, frames, f"{file_prefix}-ov", fps=5.0, start_lag=lag_start, end_lag=end_lag, formats=D_FMT)

                print(f"GENERATING: {file_prefix}-ov-path")
                # Overlaid, with trajectory
                mod_params["draw_drone"] = True
                mod_params["draw_trajectory"] = True
                frames = viz.top_down_visualization(env_id, seg_idx, seg, domain, mod_params)
                save_frames(viz, frames, f"{file_prefix}-ov-path", fps=5.0, start_lag=lag_start, end_lag=end_lag, formats=Y_FMT)

                print(f"GENERATING: {file_prefix}-path")
                # Non-overlaid, with trajectory
                mod_params["draw_topdown"] = False
                mod_params["draw_drone"] = True
                mod_params["draw_trajectory"] = True
                frames = viz.top_down_visualization(env_id, seg_idx, seg, domain, mod_params)
                save_frames(viz, frames, f"{file_prefix}-path", fps=5.0, start_lag=lag_start, end_lag=end_lag, formats=D_FMT)

            save_map_permutations(f"{base_path}-vdist", "vdist")

            save_map_permutations(f"{base_path}-semantic-map", "S_W")

            save_map_permutations(f"{base_path}-semantic-map-gray", "S_W_Gray")

            save_map_permutations(f"{base_path}-proj-features", "F_W")

            save_map_permutations(f"{base_path}-grounding-map", "R_W")

            save_map_permutations(f"{base_path}-grounding-map-gray", "R_W_Gray")

            save_map_permutations(f"{base_path}-mask", "M_W")

            save_map_permutations(f"{base_path}-accum-mask", "M_W_accum")

            save_map_permutations(f"{base_path}-accum-mask-inv", "M_W_accum_inv")

            # Animation of FPV features
            fpv_feature_frames = viz.grab_frames(env_id, seg_idx, seg, domain, "F_C")
            save_frames(viz, fpv_feature_frames, f"{base_path}-features-fpv", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            # Animation of FPV images
            fpv_image_frames = viz.grab_frames(env_id, seg_idx, seg, domain, "image", scale=4)
            save_frames(viz, fpv_image_frames, f"{base_path}-image", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            frames = viz.overlay_frames(fpv_image_frames, fpv_feature_frames)
            save_frames(viz, frames, f"{base_path}-features-fpv-ov", fps=5.0, start_lag=lag_start, end_lag=end_lag)

            num_frames = len(frames)

            # Clip rollout videos to correct rollout duration and re-save
            rollout_dir = get_rollout_video_dir(run_name=run_name)
            if os.path.isdir(rollout_dir):
                print("Processing rollout videos")
                actual_rollout_duration = num_frames / 5.0
                ceiling_clip = viz.load_video_clip(env_id, seg_idx, seg, domain, "ceiling", rollout_dir)
                duration_with_lag = lag_start + actual_rollout_duration + end_lag
                try:
                    if ceiling_clip is not None:
                        if ceiling_clip.duration > duration_with_lag:
                            start = ceiling_clip.duration - end_lag - duration_with_lag
                            ceiling_clip = ceiling_clip.cutout(0, start)
                            #ceiling_clip = ceiling_clip.cutout(duration_with_lag, ceiling_clip.duration)
                        save_frames(viz, ceiling_clip, f"{base_path}-ceiing_cam-clipped", fps=ceiling_clip.fps)
                    corner_clip = viz.load_video_clip(env_id, seg_idx, seg, domain, "corner", rollout_dir)
                    if corner_clip is not None:
                        if corner_clip.duration > actual_rollout_duration + end_lag:
                            start = corner_clip.duration - end_lag - duration_with_lag
                            corner_clip = corner_clip.cutout(0, start)
                            #corner_clip = corner_clip.cutout(duration_with_lag, corner_clip.duration)
                        save_frames(viz, corner_clip, f"{base_path}-corner_cam-clipped", fps=corner_clip.fps)
                except Exception as e:
                    print("Video encoding error! Copying manually")
                    print(e)

                try:
                    in_ceil_file = os.path.join(rollout_dir, f"rollout_ceiling_{env_id}-0-{seg_idx}.mkv")
                    in_corn_file = os.path.join(rollout_dir, f"rollout_corner_{env_id}-0-{seg_idx}.mkv")
                    out_ceil_file = f"{base_path}-ceiling_cam-full.mkv"
                    out_corn_file = f"{base_path}-corner_cam-full.mkv"
                    shutil.copy(in_ceil_file, out_ceil_file)
                    shutil.copy(in_corn_file, out_corn_file)
                except Exception as e:
                    print("Failed copying videos! SKipping")

        print("ding")
Example No. 16
def provider_lm_pos_lm_indices_fpv(segment_data, data):
    """
    Data provider that gives the positions and indices of all landmarks visible in the FPV image.
    :param segment_data: segment dataset for which to provide data
    :return: ("lm_pos", lm_pos) - lm_pos is a list (over timesteps) of lists (over landmarks visible in image) of the
                landmark locations in image pixel coordinates
             ("lm_indices", lm_indices) - lm_indices is a list (over timesteps) of lists (over landmarks visible in image)
                of the landmark indices for every landmark included in lm_pos. These are the landmark classifier labels
    """
    env_id = segment_data[0]["metadata"]["env_id"]

    #if INSTRUCTIONS_FROM_FILE:
    #    env_instr = load_instructions(env_id)

    conf_json = load_env_config(env_id)
    all_landmark_indices = get_landmark_name_to_index()
    landmark_names, landmark_indices, landmark_pos = get_landmark_locations_airsim(
        conf_json)

    params = P.get_current_parameters().get(
        "ModelPVN") or P.get_current_parameters().get("Model")
    projector = PinholeCameraProjection(map_size=params["global_map_size"],
                                        map_world_size=params["world_size_px"],
                                        world_size=params["world_size_m"],
                                        img_x=params["img_w"],
                                        img_y=params["img_h"],
                                        cam_fov=params["cam_h_fov"],
                                        use_depth=False)
    traj_len = len(segment_data)

    lm_pos_fpv = []
    lm_indices = []
    lm_mentioned = []
    lm_pos_map = []

    for timestep in range(traj_len):
        t_lm_pos_fpv = []
        t_lm_indices = []
        t_lm_mentioned = []
        t_lm_pos_map = []

        if segment_data[timestep]["state"] is not None:
            cam_pos = segment_data[timestep]["state"].get_cam_pos_3d()
            cam_rot = segment_data[timestep]["state"].get_cam_rot()
            #if INSTRUCTIONS_FROM_FILE:
            #    instruction_str = env_instr
            #else:
            instruction_str = segment_data[timestep]["instruction"]
            mentioned_landmark_names, mentioned_landmark_indices = get_mentioned_landmarks_nl(
                instruction_str)

            for i, landmark_in_world in enumerate(landmark_pos):
                landmark_idx = landmark_indices[i]
                landmark_in_img, landmark_in_cam, status = projector.world_point_to_image(
                    cam_pos, cam_rot, landmark_in_world)
                this_lm_mentioned = 1 if landmark_idx in mentioned_landmark_indices else 0

                #landmark_in_map = world2map(landmark_in_world)

                # This is None if the landmark is behind the camera.
                if landmark_in_img is not None:
                    # presenter.save_image(images[timestep], name="tmp.png", torch=True, draw_point=landmark_in_img)
                    t_lm_pos_fpv.append(landmark_in_img[0:2])
                    t_lm_pos_map.append(landmark_in_world[0:2])
                    t_lm_indices.append(landmark_idx)
                    t_lm_mentioned.append(this_lm_mentioned)

        if len(t_lm_pos_fpv) > 0:
            t_lm_pos_fpv = torch.from_numpy(np.asarray(t_lm_pos_fpv)).float()
            t_lm_pos_map = torch.from_numpy(np.asarray(t_lm_pos_map)).float()
            t_lm_indices = torch.from_numpy(np.asarray(t_lm_indices)).long()
            t_lm_mentioned = torch.from_numpy(
                np.asarray(t_lm_mentioned)).long()
        else:
            t_lm_pos_fpv = None
            t_lm_pos_map = None
            t_lm_indices = None
            t_lm_mentioned = None

        lm_pos_fpv.append(t_lm_pos_fpv)
        lm_pos_map.append(t_lm_pos_map)
        lm_indices.append(t_lm_indices)
        lm_mentioned.append(t_lm_mentioned)

    return [("lm_pos_fpv", lm_pos_fpv), ("lm_indices", lm_indices),
            ("lm_mentioned", lm_mentioned), ("lm_pos_map", lm_pos_map)]
Example No. 17
def train_dagger():
    P.initialize_experiment()
    global PARAMS
    PARAMS = P.get_current_parameters()["Dagger"]
    setup = P.get_current_parameters()["Setup"]
    roller = pick_policy_roller(setup)

    save_json(PARAMS,
              get_dagger_data_dir(setup, real_drone=False) + "run_params.json")

    # Load less teacher-forced (supervised) data, but sample dagger rollouts from more environments to avoid overfitting.
    train_envs, dev_envs, test_envs = data_io.instructions.get_restricted_env_id_lists(
        max_envs=PARAMS["max_envs_dag"])

    all_train_data_real, all_dev_data_real = \
        data_io.train_data.load_supervised_data("real", max_envs=PARAMS["max_envs_sup"], split_segments=PARAMS["segment_level"])
    all_train_data_sim, all_dev_data_sim = \
        data_io.train_data.load_supervised_data("simulator", max_envs=PARAMS["max_envs_sup"], split_segments=PARAMS["segment_level"])

    print("Loaded data: ")
    print(
        f"   Real train {len(all_train_data_real)}, dev {len(all_dev_data_real)}"
    )
    print(
        f"   Sim train {len(all_train_data_sim)}, dev {len(all_dev_data_sim)}")

    # Load and re-save models from supervised learning stage
    model_sim, _ = load_model(setup["model"],
                              setup["sim_model_file"],
                              domain="sim")
    model_real, _ = load_model(setup["model"],
                               setup["real_model_file"],
                               domain="real")
    model_critic, _ = load_model(setup["critic_model"],
                                 setup["critic_model_file"])
    data_io.model_io.save_pytorch_model(
        model_sim, get_latest_model_filename(setup, "sim"))
    data_io.model_io.save_pytorch_model(
        model_real, get_latest_model_filename(setup, "real"))
    data_io.model_io.save_pytorch_model(
        model_critic, get_latest_model_filename(setup, "critic"))

    last_trainer_state = None

    for iteration in range(0, PARAMS["max_iterations"]):
        gc.collect()
        print("-------------------------------")
        print("DAGGER ITERATION : ", iteration)
        print("-------------------------------")

        # If we have too many training examples in memory, discard uniformly at random to keep a somewhat fixed bound
        max_samples = PARAMS["max_samples_in_memory"]
        all_train_data_real = discard_if_too_many(all_train_data_real,
                                                  max_samples)
        all_train_data_sim = discard_if_too_many(all_train_data_sim,
                                                 max_samples)

        # Roll out new data in simulation only
        latest_model_filename_sim = get_latest_model_filename(setup, "sim")
        train_data_i_sim, dev_data_i_sim = collect_iteration_data(
            roller, iteration, train_envs, test_envs,
            latest_model_filename_sim)

        # TODO: Save
        #data_io.train_data.save_dataset(dataset_name, train_data_i, dagger_data_dir + "train_" + str(iteration))
        #data_io.train_data.save_dataset(dataset_name, test_data_i, dagger_data_dir + "test_" + str(iteration))

        # Aggregate the dataset
        all_train_data_sim += train_data_i_sim
        all_dev_data_sim += dev_data_i_sim
        print("Aggregated dataset!)")
        print("Total samples: ", len(all_train_data_sim))
        print("New samples: ", len(train_data_i_sim))

        data_io.train_data.save_dataset(
            "sim_dagger", all_train_data_sim,
            get_dagger_data_dir(setup, False) + "train_latest")
        data_io.train_data.save_dataset(
            "sim_dagger", dev_data_i_sim,
            get_dagger_data_dir(setup, False) + "test_latest")

        model_sim, _ = load_model(setup["model"],
                                  get_latest_model_filename(setup, "sim"),
                                  domain="sim")
        model_real, _ = load_model(setup["model"],
                                   get_latest_model_filename(setup, "real"),
                                   domain="real")
        model_critic, _ = load_model(
            setup["critic_model"], get_latest_model_filename(setup, "critic"))

        trainer = TrainerBidomain(model_real,
                                  model_sim,
                                  model_critic,
                                  state=last_trainer_state)

        # Hacky reset of the rollout flag after doing the rollouts
        import rollout.run_metadata as run_md
        run_md.IS_ROLLOUT = False

        # Train on the newly aggregated dataset
        num_epochs = PARAMS["epochs_per_iteration"]
        for epoch in range(num_epochs):

            loss = trainer.train_epoch(data_list_real=all_train_data_real,
                                       data_list_sim=all_train_data_sim)
            dev_loss = trainer.train_epoch(data_list_real=all_dev_data_real,
                                           data_list_sim=dev_data_i_sim,
                                           eval=True)

            data_io.model_io.save_pytorch_model(
                model_sim, get_latest_model_filename(setup, "sim"))
            data_io.model_io.save_pytorch_model(
                model_real, get_latest_model_filename(setup, "real"))
            data_io.model_io.save_pytorch_model(
                model_critic, get_latest_model_filename(setup, "critic"))

            print("Epoch", epoch, "Loss: Train:", loss, "Test:", dev_loss)

        data_io.model_io.save_pytorch_model(
            model_real,
            get_model_filename_at_iteration(setup, iteration, "real"))
        data_io.model_io.save_pytorch_model(
            model_sim, get_model_filename_at_iteration(setup, iteration,
                                                       "sim"))
        data_io.model_io.save_pytorch_model(
            model_critic,
            get_model_filename_at_iteration(setup, iteration, "critic"))

        last_trainer_state = trainer.get_state()
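
discard_if_too_many is not shown in this example; a minimal sketch consistent with the comment above (uniform random subsampling down to a fixed bound), offered as an assumption rather than the repository's actual implementation:

import random

def discard_if_too_many(samples, max_samples):
    # Keep at most max_samples items, discarding uniformly at random
    if len(samples) <= max_samples:
        return samples
    return random.sample(samples, max_samples)
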
Example No. 18
def evaluate():
    P.initialize_experiment()

    model, model_loaded = load_model()
    eval_envs = get_correct_eval_env_id_list()

    model.eval()
    dataset_name = P.get_current_parameters().get("Data").get("dataset_name")
    dataset = model.get_dataset(data=None,
                                envs=eval_envs,
                                # fall back to "supervised" when no dataset_name is configured
                                dataset_prefix=dataset_name or "supervised",
                                eval=True,
                                seg_level=False)
    dataloader = DataLoader(dataset,
                            collate_fn=dataset.collate_fn,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True,
                            timeout=0)

    count = 0
    success = 0
    total_dist = 0

    for batch in dataloader:
        if batch is None:
            print("None batch!")
            continue

        images = batch["images"]
        instructions = batch["instr"]
        label_masks = batch["traj_labels"]

        # Each of the above is a list of lists of tensors, where the outer list is over the batch and the inner list
        # is over the segments. Loop through and accumulate loss for each batch sequentially, and for each segment.
        # Reset model state (embedding etc) between batches, but not between segments.
        # We don't process each batch in batch-mode, because it's complicated, with the varying number of segments and all.
        # TODO: This code is outdated and wrongly discretizes the goal location. Grab the fixed version from the old branch.

        batch_size = len(images)
        print("batch: ", count)
        print("successes: ", success)

        for i in range(batch_size):
            num_segments = len(instructions[i])

            for s in range(num_segments):
                instruction = cuda_var(instructions[i][s], model.is_cuda,
                                       model.cuda_device)
                instruction_mask = torch.ones_like(instruction)
                image = cuda_var(images[i][s], model.is_cuda,
                                 model.cuda_device)
                label_mask = cuda_var(label_masks[i][s], model.is_cuda,
                                      model.cuda_device)

                label_mask = model.label_pool(label_mask)

                goal_mask_l = label_mask[0, 1, :, :]
                goal_mask_l_np = goal_mask_l.data.cpu().numpy()
                goal_mask_l_flat = np.reshape(goal_mask_l_np, [-1])
                max_index_l = np.argmax(goal_mask_l_flat)
                argmax_loc_l = np.asarray([
                    int(max_index_l / goal_mask_l_np.shape[1]),
                    int(max_index_l % goal_mask_l_np.shape[1])
                ])

                if np.sum(goal_mask_l_np) < 0.01:
                    continue

                mask_pred, features, emb_loss = model(image, instruction,
                                                      instruction_mask)
                goal_mask = mask_pred[0, 1, :, :]
                goal_mask_np = goal_mask.data.cpu().numpy()
                goal_mask_flat = np.reshape(goal_mask_np, [-1])
                max_index = np.argmax(goal_mask_flat)

                argmax_loc = np.asarray([
                    int(max_index / goal_mask_np.shape[1]),
                    int(max_index % goal_mask_np.shape[1])
                ])

                dist = np.linalg.norm(argmax_loc - argmax_loc_l)
                if dist < OK_DIST:
                    success += 1
                count += 1
                total_dist += dist

    print("Correct goal predictions: ", success)
    print("Total evaluations: ", count)
    print("total dist: ", total_dist)
    print("avg dist: ", total_dist / float(count))
    print("success rate: ", success / float(count))
Example No. 19
def train_supervised():
    initialize_experiment()

    setup = get_current_parameters()["Setup"]
    supervised_params = get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]

    model, model_loaded = load_model()
    # import pdb; pdb.set_trace()
    # import pickle
    # with open('/storage/dxsun/model_input.pickle', 'rb') as f: data = pickle.load(f)
    # g = model(data['images'], data['states'], data['instructions'], data['instr_lengths'], data['has_obs'], data['plan'], data['save_maps_only'], data['pos_enc'], data['noisy_poses'], data['start_poses'], data['firstseg'])
    print("model:", model)
    print("model type:", type(model))
    print("Loading data")
    # import pdb;pdb.set_trace()
    train_envs, dev_envs, test_envs = get_all_env_id_lists(
        max_envs=setup["max_envs"])
    if "split_train_data" in supervised_params and supervised_params[
            "split_train_data"]:
        split_name = supervised_params["train_data_split"]
        split = load_env_split()[split_name]
        train_envs = [env_id for env_id in train_envs if env_id in split]
        print("Using " + str(len(train_envs)) + " envs from dataset split: " +
              split_name)

    filename = "supervised_" + setup["model"] + "_" + setup["run_name"]

    # Code looks weird here because load_pytorch_model adds ".pytorch" to end of path, but
    # file_exists doesn't
    model_path = "tmp/" + filename + "_epoch_" + str(
        supervised_params["start_epoch"])
    model_path_with_extension = model_path + ".pytorch"
    print("model path:", model_path_with_extension)
    if supervised_params["start_epoch"] > 0:
        if file_exists(model_path_with_extension):
            print("THE FILE EXISTS code1")
            load_pytorch_model(model, model_path)
        else:
            print("Couldn't continue training. Model file doesn't exist at:")
            print(model_path_with_extension)
            exit(-1)
    # import pdb;pdb.set_trace()
    ## If you just want to use the pretrained model
    # load_pytorch_model(model, "supervised_pvn_stage1_train_corl_pvn_stage1")

    # all_train_data, all_test_data = data_io.train_data.load_supervised_data(max_envs=100)
    if setup["restore_weights_name"]:
        restore_pretrained_weights(model, setup["restore_weights_name"],
                                   setup["fix_restored_weights"])

    # Add a tensorboard logger to the model and trainer
    tensorboard_dir = get_current_parameters()['Environment']['tensorboard_dir']
    logger = Logger(tensorboard_dir)
    model.logger = logger
    if hasattr(model, "goal_good_criterion"):
        print("gave logger to goal evaluator")
        model.goal_good_criterion.logger = logger

    trainer = Trainer(model,
                      epoch=supervised_params["start_epoch"],
                      name=setup["model"],
                      run_name=setup["run_name"])

    trainer.logger = logger

    # import pdb;pdb.set_trace()
    print("Beginning training...")
    best_test_loss = 1000

    continue_epoch = supervised_params["start_epoch"] + 1 if supervised_params[
        "start_epoch"] > 0 else 0
    rng = range(0, num_epochs)
    print("filename:", filename)

    # Leftover debugging breakpoint, disabled so training can run unattended
    # import pdb
    # pdb.set_trace()

    for epoch in rng:
        # import pdb;pdb.set_trace()
        train_loss = trainer.train_epoch(train_data=None,
                                         train_envs=train_envs,
                                         eval=False)
        # train_loss = trainer.train_epoch(train_data=all_train_data, train_envs=train_envs, eval=False)

        trainer.model.correct_goals = 0
        trainer.model.total_goals = 0

        test_loss = trainer.train_epoch(train_data=None,
                                        train_envs=dev_envs,
                                        eval=True)

        print("GOALS: ", trainer.model.correct_goals,
              trainer.model.total_goals)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            save_pytorch_model(trainer.model, filename)
            print("Saved model in:", filename)
        print("Epoch", epoch, "train_loss:", train_loss, "test_loss:",
              test_loss)
        save_pytorch_model(trainer.model,
                           "tmp/" + filename + "_epoch_" + str(epoch))
        if hasattr(trainer.model, "save"):
            trainer.model.save(epoch)
        save_pretrained_weights(trainer.model, setup["run_name"])
Example No. 20
    def __init__(self, run_name="", model_instance_name=""):

        super(ModelGSMNBiDomain, self).__init__()
        self.model_name = "gsmn_bidomain"
        self.run_name = run_name
        self.name = model_instance_name or ""
        self.writer = LoggingSummaryWriter(
            log_dir=f"runs/{run_name}/{self.name}")

        self.params = get_current_parameters()["Model"]
        self.aux_weights = get_current_parameters()["AuxWeights"]
        self.use_aux = self.params["UseAuxiliaries"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        self.tensor_store = KeyTensorStore()
        self.aux_losses = AuxiliaryLosses()

        self.rviz = None
        if self.params.get("rviz"):
            self.rviz = RvizInterface(
                base_name="/gsmn/",
                map_topics=["semantic_map", "grounding_map", "goal_map"],
                markerarray_topics=["instruction"])

        # Path-pred FPV model definition
        # --------------------------------------------------------------------------------------------------------------

        self.img_to_features_w = FPVToGlobalMap(
            source_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"],
            res_channels=self.params["resnet_channels"],
            map_channels=self.params["feature_channels"],
            img_w=self.params["img_w"],
            img_h=self.params["img_h"],
            cam_h_fov=self.params["cam_h_fov"],
            img_dbg=IMG_DBG)

        self.map_accumulator_w = LeakyIntegratorGlobalMap(
            source_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        # Pre-process the accumulated map to do language grounding if necessary - in the world reference frame
        if self.use_aux[
                "grounding_map"] and not self.use_aux["grounding_features"]:
            self.map_processor_a_w = LangFilterMapProcessor(
                embed_size=self.params["emb_size"],
                in_channels=self.params["feature_channels"],
                out_channels=self.params["relevance_channels"],
                spatial=False,
                cat_out=True)
        else:
            self.map_processor_a_w = IdentityMapProcessor(
                source_map_size=self.params["global_map_size"],
                world_size_px=self.params["world_size_px"],
                world_size_m=self.params["world_size_m"])

        if self.use_aux["goal_map"]:
            self.map_processor_b_r = LangFilterMapProcessor(
                embed_size=self.params["emb_size"],
                in_channels=self.params["relevance_channels"],
                out_channels=self.params["goal_channels"],
                spatial=self.params["spatial_goal_filter"],
                cat_out=self.params["cat_rel_and_goal"])
        else:
            self.map_processor_b_r = IdentityMapProcessor(
                source_map_size=self.params["local_map_size"],
                world_size_px=self.params["world_size_px"],
                world_size_m=self.params["world_size_m"])

        # Common
        # --------------------------------------------------------------------------------------------------------------

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"],
            self.params["emb_size"],
            self.params["emb_layers"],
            dropout=0.0)

        self.map_transform_w_to_r = MapTransformerBase(
            source_map_size=self.params["global_map_size"],
            dest_map_size=self.params["local_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])
        self.map_transform_r_to_w = MapTransformerBase(
            source_map_size=self.params["local_map_size"],
            dest_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        # Output an action given the global semantic map
        if self.params["map_to_action"] == "downsample2":
            self.map_to_action = EgoMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"])

        elif self.params["map_to_action"] == "cropped":
            self.map_to_action = CroppedMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"])

        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        # We add all auxiliaries that are necessary. The first argument is the auxiliary name, followed by parameters,
        # followed by variable number of names of inputs. ModuleWithAuxiliaries will automatically collect these inputs
        # that have been saved with keep_auxiliary_input() during execution
        if self.use_aux["class_features"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_class", self.params["feature_channels"],
                                 self.params["num_landmarks"],
                                 self.params["dropout"], "fpv_features",
                                 "lm_pos_fpv_features", "lm_indices",
                                 "tensor_store"))
        if self.use_aux["grounding_features"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_ground",
                                 self.params["relevance_channels"], 2,
                                 self.params["dropout"], "fpv_features_g",
                                 "lm_pos_fpv_features", "lm_mentioned",
                                 "tensor_store"))
        if self.use_aux["class_map"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_class_map",
                                 self.params["feature_channels"],
                                 self.params["num_landmarks"],
                                 self.params["dropout"], "map_S_W",
                                 "lm_pos_map", "lm_indices", "tensor_store"))
        if self.use_aux["grounding_map"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_grounding_map",
                                 self.params["relevance_channels"], 2,
                                 self.params["dropout"], "map_R_W",
                                 "lm_pos_map", "lm_mentioned", "tensor_store"))
        if self.use_aux["goal_map"]:
            self.aux_losses.add_auxiliary(
                GoalAuxiliary2D("aux_goal_map", self.params["goal_channels"],
                                self.params["global_map_size"], "map_G_W",
                                "goal_pos_map"))
        # RSS model uses templated data for landmark and side prediction
        if self.use_aux["language"] and self.params["templates"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary("aux_lang_lm", self.params["emb_size"],
                               self.params["num_landmarks"], 1,
                               "sentence_embed", "lm_mentioned_tplt"))
            self.aux_losses.add_auxiliary(
                ClassAuxiliary("aux_lang_side", self.params["emb_size"],
                               self.params["num_sides"], 1, "sentence_embed",
                               "side_mentioned_tplt"))
        # CoRL model uses alignment-model groundings
        elif self.use_aux["language"]:
            # One output per landmark, 2 classes per output. This is for fine-tuning, so use the embedding that will be fine-tuned
            self.aux_losses.add_auxiliary(
                ClassAuxiliary("aux_lang_lm_nl", self.params["emb_size"], 2,
                               self.params["num_landmarks"], "sentence_embed",
                               "lang_lm_mentioned"))
        if self.use_aux["l1_regularization"]:
            self.aux_losses.add_auxiliary(
                FeatureRegularizationAuxiliary2D("aux_regularize_features",
                                                 "l1", "map_S_W"))
            self.aux_losses.add_auxiliary(
                FeatureRegularizationAuxiliary2D("aux_regularize_features",
                                                 "l1", "map_R_W"))

        self.goal_acc_meter = MovingAverageMeter(10)

        self.aux_losses.print_auxiliary_info()

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0
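
The input-collection mechanism referenced in the auxiliary-objectives comment can be pictured as a key-value store over named activations; a minimal sketch, not the actual KeyTensorStore implementation:

class MinimalTensorStore:
    # Forward passes keep() activations by name; auxiliary losses
    # fetch the stored tensors later when computing their losses
    def __init__(self):
        self._store = {}

    def keep_input(self, key, tensor):
        # Append, since a key may be stored once per timestep
        self._store.setdefault(key, []).append(tensor)

    def get_inputs_batch(self, key):
        return self._store.get(key, [])
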
Example No. 21
File: paths.py Project: dxsun/drif
def get_sim_executable_path():
    return get_current_parameters()["Environment"]["simulator_path"]
Example No. 22
def automatic_demo():

    P.initialize_experiment()
    instruction_display = InstructionDisplay()

    rate = Rate(0.1)

    env = PomdpInterface(
        is_real=get_current_parameters()["Setup"]["real_drone"])
    train_instructions, dev_instructions, test_instructions, corpus = get_all_instructions(
    )
    all_instr = {
        **train_instructions,
        **dev_instructions,
        **test_instructions
    }
    token2term, word2token = get_word_to_token_map(corpus)

    # Run on dev set
    interact_instructions = dev_instructions

    env_range_start = get_current_parameters()["Setup"].get(
        "env_range_start", 0)
    env_range_end = get_current_parameters()["Setup"].get(
        "env_range_end", 10e10)
    interact_instructions = {
        k: v
        for k, v in interact_instructions.items()
        if env_range_start < k < env_range_end
    }

    model, _ = load_model(get_current_parameters()["Setup"]["model"])

    # Loop over the select few examples
    while True:

        for instruction_sets in interact_instructions.values():
            for set_idx, instruction_set in enumerate(instruction_sets):
                env_id = instruction_set['env']
                found_example = None
                for example in examples:
                    if example[0] == env_id:
                        found_example = example
                if found_example is None:
                    continue
                env.set_environment(env_id, instruction_set["instructions"])

                presenter = Presenter()
                cumulative_reward = 0
                for seg_idx in range(len(instruction_set["instructions"])):
                    if seg_idx != found_example[2]:
                        continue

                    print(f"RUNNING ENV {env_id} SEG {seg_idx}")

                    real_instruction_str = instruction_set["instructions"][
                        seg_idx]["instruction"]
                    instruction_display.show_instruction(real_instruction_str)
                    valid_segment = env.set_current_segment(seg_idx)
                    if not valid_segment:
                        continue
                    state = env.reset(seg_idx)

                    for i in range(START_PAUSE):
                        instruction_display.tick()
                        time.sleep(1)

                    tok_instruction = tokenize_instruction(
                        real_instruction_str, word2token)

                    state = env.reset(seg_idx)
                    print("Executing: f{instruction_str}")
                    while True:
                        instruction_display.tick()
                        rate.sleep()
                        action, internals = model.get_action(
                            state, tok_instruction)
                        state, reward, done, expired, oob = env.step(action)
                        cumulative_reward += reward
                        #presenter.show_sample(state, action, reward, cumulative_reward, real_instruction_str)
                        #show_depth(state.image)
                        if done:
                            break

                    for i in range(END_PAUSE):
                        instruction_display.tick()
                        time.sleep(1)
                    print("Segment finished!")
                    instruction_display.show_instruction("...")

            print("Env finished!")
Example No. 23
File: paths.py Project: dxsun/drif
def get_sim_config_dir():
    directory = get_current_parameters()["Environment"]["sim_config_dir"]
    print("get_sim_config_dir:", directory)
    return directory
Example No. 24
    def __init__(self,
                 model_real,
                 model_sim,
                 model_critic,
                 model_oracle_critic=None,
                 state=None,
                 epoch=0):
        _, _, _, corpus = get_all_instructions()
        self.token2word, self.word2token = get_word_to_token_map(corpus)

        self.params = get_current_parameters()["Training"]
        self.run_name = get_current_parameters()["Setup"]["run_name"]
        self.batch_size = self.params['batch_size']
        self.iterations_per_epoch = self.params.get("iterations_per_epoch",
                                                    None)
        self.weight_decay = self.params['weight_decay']
        self.optimizer = self.params['optimizer']
        self.critic_loaders = self.params['critic_loaders']
        self.model_common_loaders = self.params['model_common_loaders']
        self.model_sim_loaders = self.params['model_sim_loaders']
        self.lr = self.params['lr']
        self.critic_steps = self.params['critic_steps']
        self.model_steps = self.params['model_steps']
        self.critic_batch_size = self.params["critic_batch_size"]
        self.model_batch_size = self.params["model_batch_size"]
        self.disable_wloss = self.params["disable_wloss"]
        self.sim_steps_per_real_step = self.params.get(
            "sim_steps_per_real_step", 1)

        self.real_dataset_names = self.params.get("real_dataset_names")
        self.sim_dataset_names = self.params.get("sim_dataset_names")

        self.bidata = self.params.get("bidata", False)
        self.sim_steps_per_common_step = self.params.get(
            "sim_steps_per_common_step", 1)

        n_params_real = get_n_params(model_real)
        n_params_real_tr = get_n_trainable_params(model_real)
        n_params_sim = get_n_params(model_sim)
        n_params_sim_tr = get_n_trainable_params(model_sim)
        n_params_c = get_n_params(model_critic)
        n_params_c_tr = get_n_trainable_params(model_critic)

        print("Training Model:")
        print("Real # model parameters: " + str(n_params_real))
        print("Real # trainable parameters: " + str(n_params_real_tr))
        print("Sim  # model parameters: " + str(n_params_sim))
        print("Sim  # trainable parameters: " + str(n_params_sim_tr))
        print("Critic  # model parameters: " + str(n_params_c))
        print("Critic  # trainable parameters: " + str(n_params_c_tr))

        # Share those modules that are to be shared between real and sim models
        if not self.params.get("disable_domain_weight_sharing"):
            print("Sharing weights between sim and real modules")
            model_real.steal_cross_domain_modules(model_sim)
        else:
            print("NOT Sharing weights between sim and real modules")

        self.model_real = model_real
        self.model_sim = model_sim
        self.model_critic = model_critic
        self.model_oracle_critic = model_oracle_critic
        if self.model_oracle_critic:
            print("Using oracle critic")

        if self.optimizer == "adam":
            Optim = optim.Adam
        elif self.optimizer == "sgd":
            Optim = optim.SGD
        else:
            raise ValueError(f"Unsuppored optimizer {self.optimizer}")

        self.optim_models = Optim(self.model_real.both_domain_parameters(
            self.model_sim),
                                  self.lr,
                                  weight_decay=self.weight_decay)
        self.optim_critic = Optim(self.critic_parameters(),
                                  self.lr,
                                  weight_decay=self.weight_decay)

        self.train_epoch_num = epoch
        self.train_segment = 0
        self.test_epoch_num = epoch
        self.test_segment = 0
        self.set_state(state)
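
steal_cross_domain_modules presumably implements weight sharing by aliasing submodule instances, so that one parameter set receives gradients from both domains. A self-contained sketch of that idea; TinyDomainModel is an illustrative stand-in, not repository code:

import torch.nn as nn

class TinyDomainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_head = nn.Linear(8, 4)  # stands in for the shared modules

    def steal_cross_domain_modules(self, other):
        # Alias the shared submodule: both models now update one parameter set
        self.shared_head = other.shared_head

real, sim = TinyDomainModel(), TinyDomainModel()
real.steal_cross_domain_modules(sim)
assert real.shared_head is sim.shared_head
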
Example No. 25
    def __init__(self):
        self.headless = P.get_current_parameters()["Environment"].get(
            "headless", False)
        self.drone_image = None
        self.coord_grid = None
Example No. 26
def evaluate():
    P.initialize_experiment()
    params = P.get_current_parameters()
    setup = params["Setup"]

    models = []
    for i in range(setup["num_workers"]):
        model, model_loaded = load_model()
        models.append(model)

    eval_envs = list(sorted(get_correct_eval_env_id_list()))

    round_size = P.get_current_parameters()["Data"].get("collect_n_at_a_time")

    # TODO: Scrap RollOutParams and use parameter server JSON params instead
    roll_out_params = RollOutParams() \
                        .setModelName(setup["model"]) \
                        .setModelFile(setup["model_file"]) \
                        .setRunName(setup["run_name"]) \
                        .setSetupName(P.get_setup_name()) \
                        .setEnvList(eval_envs) \
                        .setMaxDeviation(800) \
                        .setHorizon(setup["trajectory_length"]) \
                        .setStepsToForceStop(20) \
                        .setPlot(False) \
                        .setShowAction(False) \
                        .setIgnorePolicyStop(False) \
                        .setPlotDir("evaluate/" + setup["run_name"]) \
                        .setSavePlots(False) \
                        .setRealtimeFirstPerson(False) \
                        .setSaveSamples(False) \
                        .setBuildTrainData(False) \
                        .setSegmentReset("always") \
                        .setSegmentLevel(False) \
                        .setFirstSegmentOnly(False) \
                        .setDebug(setup["debug"]) \
                        .setCuda(setup["cuda"]) \
                        .setRealDrone(setup["real_drone"])

    custom_eval = "Eval" in params and params["Eval"]["custom_eval"]
    instructions = None
    if custom_eval:
        examples = params["Eval"]["examples"]
        eval_envs, eval_sets, eval_segs, instructions = [
            list(x) for x in zip(*examples)
        ]
        print("!! Running custom evaluation with the following setup:")
        print(examples)
        roll_out_params.setEnvList(eval_envs)
        roll_out_params.setSegList(eval_segs)
        roll_out_params.setCustomInstructions(instructions)

    if setup["num_workers"] > 1:
        roller = ParallelPolicyRoller(num_workers=setup["num_workers"])
    else:
        roller = PolicyRoller()

    if round_size:
        eval_dataset_name = data_io.paths.get_eval_tmp_dataset_name(
            setup["model"], setup["run_name"])
        eval_dataset_path = data_io.paths.get_dataset_dir(eval_dataset_name)

        cumulative_dataset = []
        if os.path.exists(eval_dataset_path):
            result = query_user_load_discard(eval_dataset_path)
            if result == "load":
                print("Loading dataset and continuing evaluation")
                cumulative_dataset = load_multiple_env_data_from_dir(
                    eval_dataset_path)
            elif result == "discard":
                print("Discarding existing evaluation data")
                shutil.rmtree(eval_dataset_path)
            elif result == "cancel":
                print("Cancelling evaluation")
                return

        os.makedirs(eval_dataset_path, exist_ok=True)

        collected_envs = set([
            rollout[0]["env_id"] for rollout in cumulative_dataset
            if len(rollout) > 0
        ])
        eval_envs = [e for e in eval_envs if e not in collected_envs]
        if setup.get("compute_results_no_rollout", False):
            eval_envs = []

        for i in range(0, len(eval_envs), round_size):
            j = min(len(eval_envs), i + round_size)
            round_envs = eval_envs[i:j]
            roll_out_params.setEnvList(round_envs)
            dataset = roller.roll_out_policy(roll_out_params)

            # Save this data
            for rollout in dataset:
                if len(rollout) == 0:
                    print(
                        "WARNING! DROPPING EMPTY ROLLOUTS! SHOULDN'T DO THIS")
                    continue
                # rollout is a list of samples (or, at segment level, a list
                # of segments); either way, save it under its env_id
                env_id = rollout[0]["env_id"]
                save_dataset_to_path(
                    os.path.join(eval_dataset_path, str(env_id)), rollout)

            cumulative_dataset += dataset
            print(f"Saved cumulative dataset to: {eval_dataset_path}")

        dataset = cumulative_dataset
    else:
        dataset = roller.roll_out_policy(roll_out_params)

    results = {}
    if setup["eval_landmark_side"]:
        evaler = DataEvalLandmarkSide(setup["run_name"],
                                      save_images=True,
                                      world_size=setup["world_size_m"])
        evaler.evaluate_dataset(dataset)
        results = evaler.get_results()
    if setup["eval_nl"]:
        evaler = DataEvalNL(setup["run_name"],
                            save_images=True,
                            entire_trajectory=False,
                            custom_instr=instructions)
        evaler.evaluate_dataset(dataset)
        results = evaler.get_results()

    print("Results:", results)
Example No. 27
def collect_data_on_env_list(env_list):
    setup = P.get_current_parameters()["Setup"]
    dataset_name = P.get_current_parameters()["Data"]["dataset_name"]

    if setup["num_workers"] > 1:
        roller = ParallelPolicyRoller(num_workers=setup["num_workers"])
    else:
        roller = PolicyRoller()

    group_size = P.get_current_parameters()["Data"].get(
        "collect_n_at_a_time", 5)

    wrong_paths_p = P.get_current_parameters()["Rollout"].get(
        "wrong_path_p", 0.0)

    # setSetupName is important - it allows the threads to load the same json file and initialize stuff correctly
    roll_params = RollOutParams() \
        .setModelName("oracle") \
        .setRunName(setup["run_name"]) \
        .setSetupName(P.get_setup_name()) \
        .setSavePlots(False) \
        .setSaveSamples(False) \
        .setSegmentLevel(False) \
        .setPlot(False) \
        .setBuildTrainData(False) \
        .setRealDrone(setup["real_drone"]) \
        .setCuda(setup["cuda"]) \
        .setSegmentReset("always") \
        .setWrongPathP(wrong_paths_p)

    # Collect training data
    print("Collecting training data!")

    if setup.get("env_range_start") > 0:
        env_list = [e for e in env_list if e >= setup["env_range_start"]]
    if setup.get("env_range_end") > 0:
        env_list = [e for e in env_list if e < setup["env_range_end"]]

    env_list = env_list[:setup["max_envs"]]
    env_list = filter_uncollected_envs(dataset_name, env_list)

    group_size = setup["num_workers"] * group_size

    kill_airsim_every_n_rounds = 50
    round_counter = 0

    for i in range(0, len(env_list), group_size):
        # Rollout on group_size envs at a time. After each group, land the drone and save the data
        round_envs = env_list[i:i + group_size]
        roll_params.setEnvList(round_envs)
        env_datas = roller.roll_out_policy(roll_params)
        for j in range(len(env_datas)):
            env_data = env_datas[j]
            if len(env_data) > 0:
                env_id = env_data[0]["env_id"]
                filename = get_supervised_data_filename(env_id)
                save_dataset(dataset_name, env_data, filename)
            else:
                print("Empty rollout!")
        # AirSim tends to clog up and become slow. Kill it every so often to restart it.
        round_counter += 1
        if round_counter > kill_airsim_every_n_rounds:
            round_counter = 0
            killAirSim(do_kill=True)
Example No. 28
    def __init__(self):
        params = get_current_parameters()["BaselineStraight"]
        self.current_step = 0
        self.avg_fwd_vel = params["AvgSpeed"]
        self.avg_num_steps = params["AvgSteps"]
Example No. 29
def train_supervised_worker(rl_process_conn):
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    rlsup = P.get_current_parameters()["RLSUP"]
    setup["trajectory_length"] = setup["sup_trajectory_length"]
    run_name = setup["run_name"]
    supervised_params = P.get_current_parameters()["Supervised"]
    num_epochs = supervised_params["num_epochs"]
    sup_device = rlsup.get("sup_device", "cuda:1")

    model_oracle_critic = None

    print("SUPP: Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    # Load the starter model and save it at epoch 0
    # Supervised worker to use GPU 1, RL will use GPU 0. Simulators run on GPU 2
    model_sim = load_model(setup["sup_model"],
                           setup["sim_model_file"],
                           domain="sim")[0].to(sup_device)
    model_real = load_model(setup["sup_model"],
                            setup["real_model_file"],
                            domain="real")[0].to(sup_device)
    model_critic = load_model(setup["sup_critic_model"],
                              setup["critic_model_file"])[0].to(sup_device)

    # ----------------------------------------------------------------------------------------------------------------

    print("SUPP: Initializing trainer")
    rlsup_params = P.get_current_parameters()["RLSUP"]
    sim_seed_dataset = rlsup_params.get("sim_seed_dataset")

    # TODO: Figure if 6000 or 7000 here
    trainer = TrainerBidomainBidata(model_real,
                                    model_sim,
                                    model_critic,
                                    model_oracle_critic,
                                    epoch=0)
    train_envs_common = [e for e in train_envs if 6000 <= e < 7000]
    train_envs_sim = [e for e in train_envs if e < 7000]
    dev_envs_common = [e for e in dev_envs if 6000 <= e < 7000]
    dev_envs_sim = [e for e in dev_envs if e < 7000]
    sim_datasets = [rl_dataset_name(run_name)]
    real_datasets = ["real"]
    trainer.set_dataset_names(sim_datasets=sim_datasets,
                              real_datasets=real_datasets)

    # ----------------------------------------------------------------------------------------------------------------
    for start_sup_epoch in range(10000):
        epfname = epoch_sup_filename(run_name,
                                     start_sup_epoch,
                                     model="stage1",
                                     domain="sim")
        path = os.path.join(get_model_dir(), str(epfname) + ".pytorch")
        if not os.path.exists(path):
            break
    if start_sup_epoch > 0:
        print(f"SUPP: CONTINUING SUP TRAINING FROM EPOCH: {start_sup_epoch}")
        load_pytorch_model(
            model_real,
            epoch_sup_filename(run_name,
                               start_sup_epoch - 1,
                               model="stage1",
                               domain="real"))
        load_pytorch_model(
            model_sim,
            epoch_sup_filename(run_name,
                               start_sup_epoch - 1,
                               model="stage1",
                               domain="sim"))
        load_pytorch_model(
            model_critic,
            epoch_sup_filename(run_name,
                               start_sup_epoch - 1,
                               model="critic",
                               domain="critic"))
        trainer.set_start_epoch(start_sup_epoch)

    # ----------------------------------------------------------------------------------------------------------------
    print("SUPP: Beginning training...")
    for epoch in range(start_sup_epoch, num_epochs):
        # Tell the RL process that a new Stage 1 model is ready for loading
        print("SUPP: Sending model to RL")
        model_sim.reset()
        rl_process_conn.send(
            ["stage1_model_state_dict",
             model_sim.state_dict()])
        if DEBUG_RL:
            while True:
                sleep(1)

        if not sim_seed_dataset:
            ddir = get_dataset_dir(rl_dataset_name(run_name))
            os.makedirs(ddir, exist_ok=True)
            while len(os.listdir(ddir)) < 20:
                print("SUPP: Waiting for rollouts to appear")
                sleep(3)

        print("SUPP: Beginning Epoch")
        train_loss = trainer.train_epoch(env_list_common=train_envs_common,
                                         env_list_sim=train_envs_sim,
                                         eval=False)
        test_loss = trainer.train_epoch(env_list_common=dev_envs_common,
                                        env_list_sim=dev_envs_sim,
                                        eval=True)
        print("SUPP: Epoch", epoch, "train_loss:", train_loss, "test_loss:",
              test_loss)
        save_pytorch_model(
            model_real,
            epoch_sup_filename(run_name, epoch, model="stage1", domain="real"))
        save_pytorch_model(
            model_sim,
            epoch_sup_filename(run_name, epoch, model="stage1", domain="sim"))
        save_pytorch_model(
            model_critic,
            epoch_sup_filename(run_name,
                               epoch,
                               model="critic",
                               domain="critic"))
Example No. 30
    def sup_loss_on_batch(self, batch, eval):
        self.prof.tick("out")

        action_loss_total = Variable(empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        images = self.maybe_cuda(batch["images"])

        instructions = self.maybe_cuda(batch["instr"])
        instr_lengths = batch["instr_len"]
        states = self.maybe_cuda(batch["states"])
        actions = self.maybe_cuda(batch["actions"])

        # Auxiliary labels
        lm_pos_fpv = batch["lm_pos_fpv"]
        lm_indices = batch["lm_indices"]
        lm_mentioned = batch["lm_mentioned"]
        lang_lm_mentioned = batch["lang_lm_mentioned"]

        templates = get_current_parameters()["Environment"]["Templates"]
        if templates:
            lm_mentioned_tplt = batch["lm_mentioned_tplt"]
            side_mentioned_tplt = batch["side_mentioned_tplt"]

        # stops = self.maybe_cuda(batch["stops"])
        masks = self.maybe_cuda(batch["masks"])
        metadata = batch["md"]

        seq_len = images.size(1)
        batch_size = images.size(0)
        count = 0
        correct_goal_count = 0
        goal_count = 0

        # Loop thru batch
        for b in range(batch_size):
            seg_idx = -1

            self.reset()

            self.prof.tick("out")
            b_seq_len = len_until_nones(metadata[b])

            # TODO: Generalize this
            # Slice the data according to the sequence length
            b_metadata = metadata[b][:b_seq_len]
            b_images = images[b][:b_seq_len]
            b_instructions = instructions[b][:b_seq_len]
            b_instr_len = instr_lengths[b][:b_seq_len]
            b_states = states[b][:b_seq_len]
            b_actions = actions[b][:b_seq_len]
            b_lm_pos_fpv = lm_pos_fpv[b][:b_seq_len]
            b_lm_indices = lm_indices[b][:b_seq_len]
            b_lm_mentioned = lm_mentioned[b][:b_seq_len]

            b_lm_pos_fpv = [self.cuda_var((s / RESNET_FACTOR).long()) if s is not None else None for s in b_lm_pos_fpv]
            b_lm_indices = [self.cuda_var(s) if s is not None else None for s in b_lm_indices]
            b_lm_mentioned = [self.cuda_var(s) if s is not None else None for s in b_lm_mentioned]

            # TODO: Figure out how to keep these properly. Perhaps as a whole batch is best
            # TODO: Introduce a key-value store (encapsulate instead of inherit)
            self.keep_inputs("lm_pos_fpv", b_lm_pos_fpv)
            self.keep_inputs("lm_indices", b_lm_indices)
            self.keep_inputs("lm_mentioned", b_lm_mentioned)

            # TODO: Abstract all of these if-elses in a modular way once we know which ones are necessary
            if templates:
                b_lm_mentioned_tplt = lm_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = side_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = self.cuda_var(b_side_mentioned_tplt)
                b_lm_mentioned_tplt = self.cuda_var(b_lm_mentioned_tplt)
                self.keep_inputs("lm_mentioned_tplt", b_lm_mentioned_tplt)
                self.keep_inputs("side_mentioned_tplt", b_side_mentioned_tplt)
            else:
                b_lang_lm_mentioned = self.cuda_var(lang_lm_mentioned[b][:b_seq_len])
                self.keep_inputs("lang_lm_mentioned", b_lang_lm_mentioned)


            # ----------------------------------------------------------------------------

            self.prof.tick("inputs")

            # Use a separate name for predictions: rebinding `actions` here
            # would corrupt the label tensor for subsequent batch elements
            pred_actions = self(b_images, b_states, b_instructions, b_instr_len)

            action_losses, _ = self.action_loss(b_actions, pred_actions, batchreduce=False)

            self.prof.tick("call")
            action_losses = self.action_loss.batch_reduce_loss(action_losses)
            action_loss = self.action_loss.reduce_loss(action_losses)
            action_loss_total = action_loss_total + action_loss
            count += b_seq_len

            self.prof.tick("loss")

        action_loss_avg = action_loss_total / (count + 1e-9)

        self.prof.tick("out")

        # Doing this at the end (outside of the per-sequence loop)
        aux_losses = self.calculate_aux_loss(reduce_average=True)
        aux_loss = self.combine_aux_losses(aux_losses, self.aux_weights)

        prefix = self.model_name + ("/eval" if eval else "/train")

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.writer.add_dict(prefix, aux_losses, self.get_iter())
        self.writer.add_scalar(prefix + "/action_loss", action_loss_avg.data.cpu()[0], self.get_iter())

        self.prof.tick("auxiliaries")

        total_loss = action_loss_avg + aux_loss

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return total_loss
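
len_until_nones, used above to recover the true sequence length before None padding, has straightforward assumed semantics; a minimal sketch:

def len_until_nones(sequence):
    # Number of leading elements before the first None (padding marker)
    for i, item in enumerate(sequence):
        if item is None:
            return i
    return len(sequence)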