Example #1
    def gen_neg_instructions(self, env_id, seg_idx):
        # If configured to use only similar instructions from the json file, initialize
        # choices with them. Otherwise leave choices empty; it is filled below by
        # sampling a random environment and segment.
        if self.instr_negatives_similar_only:
            choices = self.similar_instruction_map[str(env_id)][str(seg_idx)]
        else:
            choices = []
        # If there are no similar instructions for this one, pick a completely random instruction instead
        while len(choices) == 0:
            env_options = list(self.similar_instruction_map.keys())
            random_env = random.choice(env_options)
            seg_options = list(self.similar_instruction_map[random_env].keys())
            if len(seg_options) == 0:
                continue
            random_seg = random.choice(seg_options)
            choices = self.similar_instruction_map[random_env][random_seg]

        pick = random.choice(choices)
        picked_env = pick["env_id"]
        picked_seg = pick["seg_idx"]
        picked_set = pick["set_idx"]
        picked_instruction = self.all_instr[picked_env][picked_set]["instructions"][picked_seg]["instruction"]
        tok_fake_instruction = tokenize_instruction(picked_instruction, self.word2token)
        return torch.LongTensor(tok_fake_instruction).unsqueeze(0)
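
tokenize_instruction and word2token are defined elsewhere in the repository. A minimal sketch of the behavior the tensors above are assumed to rely on (split on whitespace, map known words to integer ids, fall back to an unknown id):

def tokenize_instruction_sketch(instruction, word2token, unk_id=0):
    # Hypothetical stand-in for the repo's tokenize_instruction
    words = instruction.lower().split()
    return [word2token.get(w, unk_id) for w in words]

# Usage: torch.LongTensor(tokenize_instruction_sketch("turn left at the tree", word2token)).unsqueeze(0)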
Example #2
def analyze_instruction_set(name, iset, corpus, merge_len):
    token_lengths = []
    demo_lengths = []
    token2word, word2token = get_word_to_token_map(corpus)

    for e, instr_sets in iset.items():
        segs = instr_sets[0]["instructions"]
        if len(segs) > 0:
            full_path = load_path(e)
        for seg in segs:
            if seg["merge_len"] != merge_len:
                continue
            tok_i = tokenize_instruction(seg['instruction'], word2token)
            start_idx = seg["start_idx"]
            end_idx = seg["end_idx"]
            seg_path = full_path[start_idx:end_idx]
            demo_len = path_length(seg_path)

            demo_lengths.append(demo_len)
            token_lengths.append(len(tok_i))

    avg_tok_len = sum(token_lengths) / len(token_lengths)
    # The 4.7 / 1000 factor is assumed to convert raw path units into meters
    avg_pth_len = sum(demo_lengths) * 4.7 / (len(demo_lengths) * 1000)

    print("Dataset: ", name)
    print(" {}  &  {}  &  {:.2f}  &  {:.2f}".format(len(iset), len(token_lengths), avg_tok_len, avg_pth_len))
Example #3
    def gen_instruction(self, instruction):
        tok_instruction = tokenize_instruction(instruction, self.word2token)
        instruction_t = torch.LongTensor(tok_instruction)

        # If we're doing segment level, we want to support batching later on.
        # Otherwise each instance is a batch in itself
        # TODO Move unsqueezing into the collate_fn
        if not self.seg_level:
            instruction_t = instruction_t.unsqueeze(0)
        return instruction_t
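
The TODO above suggests handling the batch dimension in a collate_fn instead. A minimal sketch of such a function, assuming each item is a 1-D LongTensor of token ids; this is not the repo's actual collate_fn:

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_instructions(batch, pad_value=0):
    # Pad variable-length instruction tensors to [batch, max_len] and keep lengths for masking
    lengths = torch.LongTensor([t.size(0) for t in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=pad_value)
    return padded, lengths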
Example #4
    def __getitem__(self, idx):
        if self.seg_level:
            env_id = self.seg_list[idx][0]
            set_idx = self.seg_list[idx][1]
            seg_idx = self.seg_list[idx][2]
        else:
            env_id = self.env_list[idx]

        print("top_down_dataset_sm __getitem__ load_env_config")
        env_conf_json = load_env_config(env_id)
        landmark_names, landmark_indices, landmark_positions = get_landmark_locations_airsim(env_conf_json)

        top_down_image = load_env_img(env_id)

        path = load_path(env_id)

        img_x = top_down_image.shape[0]
        img_y = top_down_image.shape[1]

        path_in_img_coords = self.cf_to_img(img_x, path)
        landmark_pos_in_img = self.as_to_img(img_x, np.asarray(landmark_positions)[:, 0:2])
        self.pos_rand_image = self.pos_rand_range * img_x

        #self.plot_path_on_img(top_down_image, path_in_img_coords)
        #self.plot_path_on_img(top_down_image, landmark_pos_in_img)
        #cv2.imshow("top_down", top_down_image)
        #cv2.waitKey()

        input_images = []
        input_instructions = []
        label_images = []
        aux_labels = []

        # Load the instruction segments, each with the start and end indices of its path slice
        if self.seg_level:
            instruction_segments = [self.all_instr[env_id][set_idx]["instructions"][seg_idx]]
        else:
            instruction_segments = self.all_instr[env_id][0]["instructions"]

        for seg_idx, seg in enumerate(instruction_segments):
            start_idx = seg["start_idx"]
            end_idx = seg["end_idx"]
            instruction = seg["instruction"]

            # TODO: Check for overflow
            seg_path = path_in_img_coords[start_idx:end_idx]
            seg_img = top_down_image.copy()

            #test_plot = self.plot_path_on_img(seg_img, seg_path)
            # TODO: Validate the 0.5 choice, should it be 2?
            affine, cropsize = self.get_affine_matrix(seg_path, 0, [int(img_x / 2), int(img_y / 2)], 0.5)
            if affine is None:
                continue
            seg_img_rot = self.apply_affine(seg_img, affine, cropsize)

            seg_labels = np.zeros_like(seg_img[:, :, 0:1]).astype(float)
            seg_labels = self.plot_path_on_img(seg_labels, seg_path)
            seg_labels = gaussian_filter(seg_labels, 4)
            seg_labels_rot = self.apply_affine(seg_labels, affine, cropsize)

            #seg_labels_rot = gaussian_filter(seg_labels_rot, 4)
            seg_labels_rot = self.normalize_0_1(seg_labels_rot)

            # Change to true to visualize the paths / labels
            if False:
                cv2.imshow("rot_img", seg_img_rot)
                cv2.imshow("seg_labels", seg_labels_rot)
                rot_viz = seg_img_rot.astype(np.float64) / 512
                rot_viz[:, :, 0] += seg_labels_rot.squeeze()
                cv2.imshow("rot_viz", rot_viz)
                cv2.waitKey(0)

            tok_instruction = tokenize_instruction(instruction, self.word2token)
            instruction_t = torch.LongTensor(tok_instruction).unsqueeze(0)

            # Get landmark classification labels
            landmark_pos_in_seg_img = self.apply_affine_on_pts(landmark_pos_in_img, affine)

            # Down-size images and labels if requested by the model
            if self.img_scale != 1.0:
                # transform.resize expects an integer output shape
                seg_img_rot = transform.resize(
                    seg_img_rot,
                    [int(seg_img_rot.shape[0] * self.img_scale),
                     int(seg_img_rot.shape[1] * self.img_scale)], mode="constant")
                seg_labels_rot = transform.resize(
                    seg_labels_rot,
                    [int(seg_labels_rot.shape[0] * self.img_scale),
                     int(seg_labels_rot.shape[1] * self.img_scale)], mode="constant")
                landmark_pos_in_seg_img = landmark_pos_in_seg_img * self.img_scale

            seg_img_rot = standardize_image(seg_img_rot)
            seg_labels_rot = standardize_image(seg_labels_rot)
            seg_img_t = torch.from_numpy(seg_img_rot).unsqueeze(0).float()
            seg_labels_t = torch.from_numpy(seg_labels_rot).unsqueeze(0).float()

            landmark_pos_t = torch.from_numpy(landmark_pos_in_seg_img).unsqueeze(0)
            landmark_indices_t = torch.LongTensor(landmark_indices).unsqueeze(0)

            mask1 = torch.gt(landmark_pos_t, 0)
            mask2 = torch.lt(landmark_pos_t, seg_img_t.size(2))
            mask = mask1 * mask2
            mask = mask[:, :, 0] * mask[:, :, 1]

            landmark_pos_t = torch.masked_select(landmark_pos_t, mask.unsqueeze(2).expand_as(landmark_pos_t)).view([-1, 2])
            landmark_indices_t = torch.masked_select(landmark_indices_t, mask).view([-1])

            mentioned_names, mentioned_indices = get_mentioned_landmarks(self.thesaurus, instruction)
            mentioned_labels_t = empty_float_tensor(list(landmark_indices_t.size())).long()
            for i, landmark_idx_present in enumerate(landmark_indices_t):
                if landmark_idx_present in mentioned_indices:
                    mentioned_labels_t[i] = 1

            aux_label = {
                "landmark_pos": landmark_pos_t,
                "landmark_indices": landmark_indices_t,
                "landmark_mentioned": mentioned_labels_t,
                "visible_mask": mask,
            }

            if self.include_instr_negatives:
                # If configured to use only similar instructions from the json file, initialize
                # choices with them. Otherwise leave choices empty; it is filled below by
                # sampling a random environment and segment.
                if self.instr_negatives_similar_only:
                    choices = self.similar_instruction_map[str(env_id)][str(seg_idx)]
                else:
                    choices = []
                # If there are no similar instructions for this one, pick a completely random instruction instead
                while len(choices) == 0:
                    env_options = list(self.similar_instruction_map.keys())
                    random_env = random.choice(env_options)
                    seg_options = list(self.similar_instruction_map[random_env].keys())
                    if len(seg_options) == 0:
                        continue
                    random_seg = random.choice(seg_options)
                    choices = self.similar_instruction_map[random_env][random_seg]

                pick = random.choice(choices)
                picked_env = pick["env_id"]
                picked_seg = pick["seg_idx"]
                picked_set = pick["set_idx"]
                picked_instruction = self.all_instr[picked_env][picked_set]["instructions"][picked_seg]["instruction"]
                tok_fake_instruction = tokenize_instruction(picked_instruction, self.word2token)
                aux_label["negative_instruction"] = torch.LongTensor(tok_fake_instruction).unsqueeze(0)

            input_images.append(seg_img_t)
            input_instructions.append(instruction_t)
            label_images.append(seg_labels_t)
            aux_labels.append(aux_label)

        return [input_images, input_instructions, label_images, aux_labels]
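
The landmark handling above keeps only landmarks whose projected coordinates fall inside the rotated crop. A toy example of that masking pattern, independent of the dataset class:

import torch

landmark_pos = torch.tensor([[[10.0, 20.0], [-5.0, 30.0], [100.0, 200.0]]])  # [1, N, 2]
landmark_idx = torch.tensor([[3, 7, 11]])                                     # [1, N]
img_size = 128

inside = (landmark_pos > 0) & (landmark_pos < img_size)  # per-coordinate bounds test
mask = inside[:, :, 0] & inside[:, :, 1]                  # both x and y must be in range

visible_pos = torch.masked_select(landmark_pos, mask.unsqueeze(2).expand_as(landmark_pos)).view(-1, 2)
visible_idx = torch.masked_select(landmark_idx, mask).view(-1)
# Only the first landmark (index 3) survives; the other two fall outside the image bounds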
Example #5
def interactive_demo():

    P.initialize_experiment()
    InteractAPI.launch_ui()

    rate = Rate(0.1)

    env = PomdpInterface(
        is_real=get_current_parameters()["Setup"]["real_drone"])
    train_instructions, dev_instructions, test_instructions, corpus = get_all_instructions(
    )
    all_instr = {
        **train_instructions,
        **dev_instructions,
        **test_instructions
    }
    token2term, word2token = get_word_to_token_map(corpus)

    # Run on dev set
    interact_instructions = dev_instructions

    env_range_start = get_current_parameters()["Setup"].get(
        "env_range_start", 0)
    env_range_end = get_current_parameters()["Setup"].get(
        "env_range_end", 10e10)
    interact_instructions = {
        k: v
        for k, v in interact_instructions.items()
        if env_range_start < k < env_range_end
    }

    count = 0
    stuck_count = 0

    model, _ = load_model(get_current_parameters()["Setup"]["model"])

    InteractAPI.write_empty_instruction()
    InteractAPI.write_real_instruction("None")
    instruction_str = InteractAPI.read_instruction_file()
    print("Initial instruction: ", instruction_str)

    for instruction_sets in interact_instructions.values():
        for set_idx, instruction_set in enumerate(instruction_sets):
            env_id = instruction_set['env']
            env.set_environment(env_id, instruction_set["instructions"])

            presenter = Presenter()
            cumulative_reward = 0
            for seg_idx in range(len(instruction_set["instructions"])):

                print(f"RUNNING ENV {env_id} SEG {seg_idx}")

                real_instruction_str = instruction_set["instructions"][
                    seg_idx]["instruction"]
                InteractAPI.write_real_instruction(real_instruction_str)
                valid_segment = env.set_current_segment(seg_idx)
                if not valid_segment:
                    continue
                state = env.reset(seg_idx)

                keep_going = True
                while keep_going:
                    InteractAPI.write_real_instruction(real_instruction_str)

                    while True:
                        cv2.waitKey(200)
                        instruction = InteractAPI.read_instruction_file()
                        if instruction == "CMD: Next":
                            print("Advancing")
                            keep_going = False
                            InteractAPI.write_empty_instruction()
                            break
                        elif instruction == "CMD: Reset":
                            print("Resetting")
                            env.reset(seg_idx)
                            InteractAPI.write_empty_instruction()
                        elif len(instruction.split(" ")) > 1:
                            instruction_str = instruction
                            break

                    if not keep_going:
                        continue

                    env.override_instruction(instruction_str)
                    tok_instruction = tokenize_instruction(
                        instruction_str, word2token)

                    state = env.reset(seg_idx)
                    print("Executing: f{instruction_str}")
                    while True:
                        rate.sleep()
                        action, internals = model.get_action(
                            state, tok_instruction)

                        state, reward, done, expired, oob = env.step(action)
                        cumulative_reward += reward
                        presenter.show_sample(state, action, reward,
                                              cumulative_reward,
                                              instruction_str)
                        #show_depth(state.image)
                        if done:
                            break
                    InteractAPI.write_empty_instruction()
                    print("Segment finished!")
        print("Env finished!")
Example #6
def train_top_down_pred():
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    launch_ui()

    env = PomdpInterface()

    print("model_name:", setup["top_down_model"])
    print("model_file:", setup["top_down_model_file"])

    model, model_loaded = load_model(
        model_name_override=setup["top_down_model"],
        model_file_override=setup["top_down_model_file"])

    exec_model, wrapper_model_loaded = load_model(
        model_name_override=setup["wrapper_model"],
        model_file_override=setup["wrapper_model_file"])

    affine2d = Affine2D()
    if model.is_cuda:
        affine2d.cuda()

    eval_envs = get_correct_eval_env_id_list()
    print("eval_envs:", eval_envs)
    train_instructions, dev_instructions, test_instructions, corpus = get_all_instructions(
        max_size=setup["max_envs"])
    all_instr = {
        **train_instructions,
        **dev_instructions,
        **test_instructions
    }
    token2term, word2token = get_word_to_token_map(corpus)

    dataset = model.get_dataset(envs=eval_envs,
                                dataset_name="supervised",
                                eval=True,
                                seg_level=False)
    dataloader = DataLoader(dataset,
                            collate_fn=dataset.collate_fn,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True)

    for b, batch in list(enumerate(dataloader)):
        print("batch:", batch)
        images = batch["images"]
        instructions = batch["instr"]
        label_masks = batch["traj_labels"]
        affines = batch["affines_g_to_s"]
        env_ids = batch["env_id"]
        set_idxs = batch["set_idx"]
        seg_idxs = batch["seg_idx"]

        env_id = env_ids[0][0]
        set_idx = set_idxs[0][0]
        print("env_id of this batch:", env_id)
        env.set_environment(
            env_id, instruction_set=all_instr[env_id][set_idx]["instructions"])
        env.reset(0)

        num_segments = len(instructions[0])
        print("num_segments in this batch:", num_segments)
        write_instruction("")
        write_real_instruction("None")
        instruction_str = read_instruction_file()
        print("Initial instruction: ", instruction_str)

        # TODO: Reset model state here if we keep any temporal memory etc
        for s in range(num_segments):
            start_state = env.reset(s)
            keep_going = True
            real_instruction = cuda_var(instructions[0][s], setup["cuda"], 0)
            tmp = list(real_instruction.data.cpu()[0].numpy())
            real_instruction_str = debug_untokenize_instruction(tmp)
            write_real_instruction(real_instruction_str)
            #write_instruction(real_instruction_str)
            #instruction_str = real_instruction_str

            image = cuda_var(images[0][s], setup["cuda"], 0)
            label_mask = cuda_var(label_masks[0][s], setup["cuda"], 0)
            affine_g_to_s = affines[0][s]
            print("Your current environment:")
            with open(
                    "/storage/dxsun/unreal_config_nl/configs/configs/random_config_"
                    + str(env_id) + ".json") as fp:
                config = json.load(fp)
            print(config)
            while keep_going:
                write_real_instruction(real_instruction_str)

                while True:
                    cv2.waitKey(200)
                    instruction = read_instruction_file()
                    if instruction == "CMD: Next":
                        print("Advancing")
                        keep_going = False
                        write_empty_instruction()
                        break
                    elif instruction == "CMD: Reset":
                        print("Resetting")
                        env.reset(s)
                        write_empty_instruction()
                    elif len(instruction.split(" ")) > 1:
                        instruction_str = instruction
                        print("Executing: ", instruction_str)
                        break

                if not keep_going:
                    continue

                #instruction_str = read_instruction_file()
                # TODO: Load instruction from file
                tok_instruction = tokenize_instruction(instruction_str,
                                                       word2token)
                instruction_t = torch.LongTensor(tok_instruction).unsqueeze(0)
                instruction_v = cuda_var(instruction_t, setup["cuda"], 0)
                instruction_mask = torch.ones_like(instruction_v)
                tmp = list(instruction_t[0].numpy())
                instruction_dbg_str = debug_untokenize_instruction(
                    tmp, token2term)

                # import matplotlib.pyplot as plt
                #plt.plot(image.squeeze(0).permute(1,2,0).cpu().numpy())
                #plt.show()

                res = model(image, instruction_v, instruction_mask)
                mask_pred = res[0]
                shp = mask_pred.shape
                mask_pred = F.softmax(mask_pred.view([2, -1]), 1).view(shp)
                #mask_pred = softmax2d(mask_pred)

                # TODO: Rotate the mask_pred to the global frame
                affine_s_to_g = np.linalg.inv(affine_g_to_s)
                S = 8.0
                affine_scale_up = np.asarray([[S, 0, 0], [0, S, 0], [0, 0, 1]])
                affine_scale_down = np.linalg.inv(affine_scale_up)

                affine_pred_to_g = np.dot(
                    affine_scale_down, np.dot(affine_s_to_g, affine_scale_up))
                #affine_pred_to_g_t = torch.from_numpy(affine_pred_to_g).float()

                mask_pred_np = mask_pred.data.cpu().numpy()[0].transpose(
                    1, 2, 0)
                mask_pred_g_np = apply_affine(mask_pred_np, affine_pred_to_g,
                                              32, 32)
                print("Sum of global mask: ", mask_pred_g_np.sum())
                mask_pred_g = torch.from_numpy(
                    mask_pred_g_np.transpose(2, 0,
                                             1)).float()[np.newaxis, :, :, :]
                exec_model.set_ground_truth_visitation_d(mask_pred_g)

                # Create a batch axis for pytorch
                #mask_pred_g = affine2d(mask_pred, affine_pred_to_g_t[np.newaxis, :, :])

                mask_pred_np[:, :, 0] -= mask_pred_np[:, :, 0].min()
                mask_pred_np[:, :, 0] /= (mask_pred_np[:, :, 0].max() + 1e-9)
                mask_pred_np[:, :, 0] *= 2.0
                mask_pred_np[:, :, 1] -= mask_pred_np[:, :, 1].min()
                mask_pred_np[:, :, 1] /= (mask_pred_np[:, :, 1].max() + 1e-9)

                presenter = Presenter()
                presenter.show_image(mask_pred_g_np,
                                     "mask_pred_g",
                                     torch=False,
                                     waitkey=1,
                                     scale=4)
                #import matplotlib.pyplot as plt
                #print("image.data shape:", image.data.cpu().numpy().shape)
                #plt.imshow(image.data.squeeze().permute(1,2,0).cpu().numpy())
                #plt.show()
                # presenter.show_image(image.data, "mask_pred_g", torch=False, waitkey=1, scale=4)
                #import pdb; pdb.set_trace()
                pred_viz_np = presenter.overlaid_image(image.data,
                                                       mask_pred_np,
                                                       channel=0)
                # TODO: Don't show labels
                # TODO: OpenCV colours
                #label_mask_np = p.data.cpu().numpy()[0].transpose(1,2,0)
                labl_viz_np = presenter.overlaid_image(image.data,
                                                       label_mask.data,
                                                       channel=0)
                viz_img_np = np.concatenate((pred_viz_np, labl_viz_np), axis=1)
                viz_img_np = pred_viz_np  # overrides the concatenation above; only the prediction is shown

                viz_img = presenter.overlay_text(viz_img_np,
                                                 instruction_dbg_str)
                cv2.imshow("interactive viz", viz_img)
                cv2.waitKey(100)

                rollout_model(exec_model, env, env_ids[0][s], set_idxs[0][s],
                              seg_idxs[0][s], tok_instruction)
                write_instruction("")
Example #7
def automatic_demo():

    P.initialize_experiment()
    instruction_display = InstructionDisplay()

    rate = Rate(0.1)

    env = PomdpInterface(
        is_real=get_current_parameters()["Setup"]["real_drone"])
    train_instructions, dev_instructions, test_instructions, corpus = get_all_instructions(
    )
    all_instr = {
        **train_instructions,
        **dev_instructions,
        **test_instructions
    }
    token2term, word2token = get_word_to_token_map(corpus)

    # Run on dev set
    interact_instructions = dev_instructions

    env_range_start = get_current_parameters()["Setup"].get(
        "env_range_start", 0)
    env_range_end = get_current_parameters()["Setup"].get(
        "env_range_end", 10e10)
    interact_instructions = {
        k: v
        for k, v in interact_instructions.items()
        if env_range_start < k < env_range_end
    }

    model, _ = load_model(get_current_parameters()["Setup"]["model"])

    # Loop over the select few examples
    while True:

        for instruction_sets in interact_instructions.values():
            for set_idx, instruction_set in enumerate(instruction_sets):
                env_id = instruction_set['env']
                found_example = None
                for example in examples:
                    if example[0] == env_id:
                        found_example = example
                if found_example is None:
                    continue
                env.set_environment(env_id, instruction_set["instructions"])

                presenter = Presenter()
                cumulative_reward = 0
                for seg_idx in range(len(instruction_set["instructions"])):
                    if seg_idx != found_example[2]:
                        continue

                    print(f"RUNNING ENV {env_id} SEG {seg_idx}")

                    real_instruction_str = instruction_set["instructions"][
                        seg_idx]["instruction"]
                    instruction_display.show_instruction(real_instruction_str)
                    valid_segment = env.set_current_segment(seg_idx)
                    if not valid_segment:
                        continue
                    state = env.reset(seg_idx)

                    for i in range(START_PAUSE):
                        instruction_display.tick()
                        time.sleep(1)

                    tok_instruction = tokenize_instruction(
                        real_instruction_str, word2token)

                    state = env.reset(seg_idx)
                    print("Executing: f{instruction_str}")
                    while True:
                        instruction_display.tick()
                        rate.sleep()
                        action, internals = model.get_action(
                            state, tok_instruction)
                        state, reward, done, expired, oob = env.step(action)
                        cumulative_reward += reward
                        #presenter.show_sample(state, action, reward, cumulative_reward, real_instruction_str)
                        #show_depth(state.image)
                        if done:
                            break

                    for i in range(END_PAUSE):
                        instruction_display.tick()
                        time.sleep(1)
                    print("Segment finished!")
                    instruction_display.show_instruction("...")

            print("Env finished!")
Example #8
    def __getitem__(self, idx):
        self.prof.tick("out")
        # If data is already loaded, use it
        if self.data is not None:
            seg_data = self.data[idx]
            raise NotImplementedError("Not implemented and tested")
            if type(seg_data) is int:
                raise NotImplementedError("Mixing dynamically loaded envs with training data is no longer supported.")
        else:
            dataset_name, env_id, seg_idx = self.sample_ids[idx]
            env_data = self.load_env_data(dataset_name, env_id)

            if self.segment_level:
                seg_data = []
                segs_in_data = set()
                for sample in env_data:
                    # This is a hack around the dataset format change - some stuff used to be inside the metadata dict,
                    # but is now moved into the root level
                    if "metadata" not in sample:
                        sample["metadata"] = sample
                    # TODO: Set this at rollout time - we know which domain we're rolling out, but this can potentially be mixed up
                    sample["metadata"]["domain"] = self.domain
                    segs_in_data.add(sample["metadata"]["seg_idx"])

                # Keep the segments for which we have instructions
                segs_in_data_and_instructions = set()
                for _seg_idx in segs_in_data:
                    if get_instruction_segment(env_id, 0, _seg_idx, all_instr=self.all_instr_full) is not None:
                        segs_in_data_and_instructions.add(_seg_idx)

                if seg_idx not in segs_in_data_and_instructions:
                    if DEBUG: print(f"Segment {env_id}::{seg_idx} not in (data)and(instructions)")
                    # If there's a single segment in this entire dataset, just return that segment even if it's not a match.
                    if len(segs_in_data) == 1:
                        seg_data = env_data
                        if DEBUG: print(f"  Only one seg in data ({segs_in_data}): returning that")
                    # Otherwise return a random segment instead
                    elif len(segs_in_data_and_instructions) > 0:
                        seg_idx = random.choice(list(segs_in_data_and_instructions))
                        if DEBUG: print(f"  Returning a random segment from (data)and(instructions): {seg_idx}")
                    elif dataset_name == "real" and len(segs_in_data) > 0:
                        seg_idx = random.choice(list(segs_in_data))
                        if DEBUG: print(f"  REAL dataset. Returning a random seg from data: {seg_idx}")
                    else:
                        seg_idx = -1
                        if DEBUG: print(f"  No segment found. Skipping example")

                if len(seg_data) == 0:
                    if DEBUG: print(f"   Grabing segment: {seg_idx}")
                    for sample in env_data:
                        if sample["metadata"]["seg_idx"] == seg_idx:
                            seg_data.append(sample)
                if DEBUG: print(f"   Returning segment data of length: {len(seg_data)}")
            else:
                seg_data = env_data
        # In RL training this often returns None, because the dataset index is built from different data than what is actually available.
        # TODO: in RL training, treat entire environment as a single segment and don't distinguish.
        # How? Check above
        if len(seg_data) < self.min_seg_len:
            print(f"   None reason: len:{len(seg_data)} in {dataset_name}, env:{env_id}, seg:{seg_idx}")
            return None

        if len(seg_data) > self.traj_len:
            seg_data = seg_data[:self.traj_len]

        seg_idx = seg_data[0]["metadata"]["seg_idx"]
        set_idx = seg_data[0]["metadata"]["set_idx"]
        env_id = seg_data[0]["metadata"]["env_id"]
        instr = get_instruction_segment(env_id, set_idx, seg_idx, all_instr=self.all_instr)
        if instr is None and dataset_name != "real":
            #print(f"{dataset_name} Seg {env_id}:{set_idx}:{seg_idx} not present in instruction data")
            return None

        instr = get_instruction_segment(env_id, set_idx, seg_idx, all_instr=self.all_instr_full)
        if instr is None:
            print(f"{dataset_name} Seg {env_id}:{set_idx}:{seg_idx} not present in FULL instruction data. WTF?")
            return None

        # Convert to tensors, replacing Nones with zero's
        images_in = [seg_data[i]["state"].image if i < len(seg_data) else None for i in range(len(seg_data))]
        states = [seg_data[i]["state"].state if i < len(seg_data) else None for i in range(len(seg_data))]

        images_np = standardize_images(images_in)
        images = none_padded_seq_to_tensor(images_np)

        #depth_images_np = standardize_depth_images(images_in)
        #depth_images = none_padded_seq_to_tensor(depth_images_np)

        states = none_padded_seq_to_tensor(states)

        actions = [s["ref_action"] for s in seg_data]
        actions = none_padded_seq_to_tensor(actions)
        stops = [1.0 if s["done"] else 0.0 for s in seg_data]

        # e.g. [1 1 1 1 1 1 0 0 0 0 .. 0] for segment with 6 samples
        mask = [1.0 if s["ref_action"] is not None else 0.0 for s in seg_data]

        stops = torch.FloatTensor(stops)
        mask = torch.FloatTensor(mask)

        # This is a list, converted to tensor in collate_fn
        #if INSTRUCTIONS_FROM_FILE:
        #    tok_instructions = [tokenize_instruction(load_instruction(md["env_id"], md["set_idx"], md["seg_idx"]), self.word2token) if s["md"] is not None else None for s in seg_data]
        #else:
        tok_instructions = [tokenize_instruction(s["instruction"], self.word2token) if s["instruction"] is not None else None for s in seg_data]

        md = [seg_data[i]["metadata"] for i in range(len(seg_data))]
        flag = md[0]["flag"] if "flag" in md[0] else None

        data = {
            "instr": tok_instructions,
            "images": images,
            #"depth_images": depth_images,
            "states": states,
            "actions": actions,
            "stops": stops,
            "masks": mask,
            "flags": flag,
            "md": md
        }

        self.prof.tick("getitem_core")
        for aux_provider_name in self.aux_provider_names:
            aux_datas = resolve_data_provider(aux_provider_name)(seg_data, data)
            for d in aux_datas:
                data[d[0]] = d[1]
            self.prof.tick("getitem_" + aux_provider_name)

        return data
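
none_padded_seq_to_tensor is used above to turn per-timestep lists that may contain None into fixed-shape tensors. A minimal sketch of the assumed behavior (substitute zeros for None entries, then stack):

import numpy as np
import torch

def none_padded_seq_to_tensor_sketch(seq):
    # Hypothetical stand-in: use the first non-None element as the shape template,
    # replace None entries with zeros of that shape, and stack along a new first axis
    template = np.asarray(next(x for x in seq if x is not None), dtype=np.float32)
    filled = [np.asarray(x, dtype=np.float32) if x is not None else np.zeros_like(template)
              for x in seq]
    return torch.from_numpy(np.stack(filled, axis=0))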