Example #1
    def forward(self, image_g, pose):

        self.set_map(image_g, None)
        image_r, _ = self.get_map(pose)

        presenter = Presenter()
        presenter.show_image(image_g[0].data,
                             "img_g",
                             torch=True,
                             waitkey=False,
                             scale=2)
        presenter.show_image(image_r[0].data,
                             "img_r",
                             torch=True,
                             waitkey=100,
                             scale=2)

        features_r = self.feature_net(image_r)
        coverage = torch.ones_like(features_r)
        return features_r, coverage, image_r
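Example #1 drives Presenter.show_image with CHW torch tensors, a window name, a scale factor, and a waitkey value. The real implementation lives in the repository's Presenter class; purely for orientation, a minimal stand-in matching the call signature used above could look like the sketch below (the function name and every detail of its behavior are assumptions, not the library's code).

import cv2


def show_image_sketch(img, name, torch=False, waitkey=False, scale=1):
    # Hypothetical stand-in for Presenter.show_image, reconstructed only from how it is
    # called in these examples: CHW torch tensors, a window name, an integer scale factor,
    # and waitkey that is False, True, or a delay in milliseconds.
    if torch:
        img = img.detach().cpu().float().numpy()
        if img.ndim == 3:
            img = img.transpose(1, 2, 0)  # CHW -> HWC for OpenCV
    img = img - img.min()
    img = img / (img.max() + 1e-9)  # normalize into [0, 1] for display
    if scale != 1:
        img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST)
    cv2.imshow(name, img)
    if waitkey:
        cv2.waitKey(0 if waitkey is True else int(waitkey))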
Example #2
def browse_pvn_dataset():
    P.initialize_experiment()

    setup = P.get_current_parameters()["Setup"]
    model_sim, _ = load_model(setup["model"],
                              setup["sim_model_file"],
                              domain="sim")
    data_params = P.get_current_parameters()["Training"]

    print("Loading data")
    train_envs, dev_envs, test_envs = get_restricted_env_id_lists()

    #dom="real"
    dom = "sim"

    dataset = model_sim.get_dataset(
        data=None,
        envs=train_envs,
        domain=dom,
        dataset_names=data_params[f"{dom}_dataset_names"],
        dataset_prefix="supervised",
        eval=False,
        halfway_only=False)

    p = Presenter()

    for example in dataset:
        if example is None:
            continue
        md = example["md"][0]
        print(
            f"Showing example: {md['env_id']}:{md['set_idx']}:{md['seg_idx']}")
        print(f"  instruction: {md['instruction']}")
        exec_len = len(example["images"])
        for i in range(exec_len):
            print(f"   timestep: {i}")
            img_i = example["images"][i]
            lm_fpv_i = example["lm_pos_fpv"][i]
            if lm_fpv_i is not None:
                img_i = p.plot_pts_on_torch_image(img_i, lm_fpv_i.long())
            p.show_image(img_i, "fpv_img_i", scale=4, waitkey=True)
def train_top_down_pred():
    P.initialize_experiment()
    setup = P.get_current_parameters()["Setup"]
    launch_ui()

    env = PomdpInterface()

    print("model_name:", setup["top_down_model"])
    print("model_file:", setup["top_down_model_file"])

    model, model_loaded = load_model(
        model_name_override=setup["top_down_model"],
        model_file_override=setup["top_down_model_file"])

    exec_model, wrapper_model_loaded = load_model(
        model_name_override=setup["wrapper_model"],
        model_file_override=setup["wrapper_model_file"])

    affine2d = Affine2D()
    if model.is_cuda:
        affine2d.cuda()

    eval_envs = get_correct_eval_env_id_list()
    print("eval_envs:", eval_envs)
    train_instructions, dev_instructions, test_instructions, corpus = get_all_instructions(
        max_size=setup["max_envs"])
    all_instr = {
        **train_instructions,
        **dev_instructions,
        **test_instructions
    }
    token2term, word2token = get_word_to_token_map(corpus)

    dataset = model.get_dataset(envs=eval_envs,
                                dataset_name="supervised",
                                eval=True,
                                seg_level=False)
    dataloader = DataLoader(dataset,
                            collate_fn=dataset.collate_fn,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True)

    for b, batch in list(enumerate(dataloader)):
        print("batch:", batch)
        images = batch["images"]
        instructions = batch["instr"]
        label_masks = batch["traj_labels"]
        affines = batch["affines_g_to_s"]
        env_ids = batch["env_id"]
        set_idxs = batch["set_idx"]
        seg_idxs = batch["seg_idx"]

        env_id = env_ids[0][0]
        set_idx = set_idxs[0][0]
        print("env_id of this batch:", env_id)
        env.set_environment(
            env_id, instruction_set=all_instr[env_id][set_idx]["instructions"])
        env.reset(0)

        num_segments = len(instructions[0])
        print("num_segments in this batch:", num_segments)
        write_instruction("")
        write_real_instruction("None")
        instruction_str = read_instruction_file()
        print("Initial instruction: ", instruction_str)

        # TODO: Reset model state here if we keep any temporal memory etc
        for s in range(num_segments):
            start_state = env.reset(s)
            keep_going = True
            real_instruction = cuda_var(instructions[0][s], setup["cuda"], 0)
            tmp = list(real_instruction.data.cpu()[0].numpy())
            real_instruction_str = debug_untokenize_instruction(tmp)
            write_real_instruction(real_instruction_str)
            #write_instruction(real_instruction_str)
            #instruction_str = real_instruction_str

            image = cuda_var(images[0][s], setup["cuda"], 0)
            label_mask = cuda_var(label_masks[0][s], setup["cuda"], 0)
            affine_g_to_s = affines[0][s]
            print("Your current environment:")
            with open(
                    "/storage/dxsun/unreal_config_nl/configs/configs/random_config_"
                    + str(env_id) + ".json") as fp:
                config = json.load(fp)
            print(config)
            while keep_going:
                write_real_instruction(real_instruction_str)

                while True:
                    cv2.waitKey(200)
                    instruction = read_instruction_file()
                    if instruction == "CMD: Next":
                        print("Advancing")
                        keep_going = False
                        write_empty_instruction()
                        break
                    elif instruction == "CMD: Reset":
                        print("Resetting")
                        env.reset(s)
                        write_empty_instruction()
                    elif len(instruction.split(" ")) > 1:
                        instruction_str = instruction
                        print("Executing: ", instruction_str)
                        break

                if not keep_going:
                    continue

                #instruction_str = read_instruction_file()
                # TODO: Load instruction from file
                tok_instruction = tokenize_instruction(instruction_str,
                                                       word2token)
                instruction_t = torch.LongTensor(tok_instruction).unsqueeze(0)
                instruction_v = cuda_var(instruction_t, setup["cuda"], 0)
                instruction_mask = torch.ones_like(instruction_v)
                tmp = list(instruction_t[0].numpy())
                instruction_dbg_str = debug_untokenize_instruction(
                    tmp, token2term)

                # import matplotlib.pyplot as plt
                #plt.plot(image.squeeze(0).permute(1,2,0).cpu().numpy())
                #plt.show()

                res = model(image, instruction_v, instruction_mask)
                mask_pred = res[0]
                shp = mask_pred.shape
                mask_pred = F.softmax(mask_pred.view([2, -1]), 1).view(shp)
                #mask_pred = softmax2d(mask_pred)

                # Rotate mask_pred into the global frame. The mask is predicted at 1/8th of the
                # resolution that affine_g_to_s is defined in, so the transform is sandwiched
                # between a scale-up and a scale-down.
                affine_s_to_g = np.linalg.inv(affine_g_to_s)
                S = 8.0
                affine_scale_up = np.asarray([[S, 0, 0], [0, S, 0], [0, 0, 1]])
                affine_scale_down = np.linalg.inv(affine_scale_up)

                affine_pred_to_g = np.dot(
                    affine_scale_down, np.dot(affine_s_to_g, affine_scale_up))
                #affine_pred_to_g_t = torch.from_numpy(affine_pred_to_g).float()

                mask_pred_np = mask_pred.data.cpu().numpy()[0].transpose(
                    1, 2, 0)
                mask_pred_g_np = apply_affine(mask_pred_np, affine_pred_to_g,
                                              32, 32)
                print("Sum of global mask: ", mask_pred_g_np.sum())
                mask_pred_g = torch.from_numpy(
                    mask_pred_g_np.transpose(2, 0,
                                             1)).float()[np.newaxis, :, :, :]
                exec_model.set_ground_truth_visitation_d(mask_pred_g)

                # Create a batch axis for pytorch
                #mask_pred_g = affine2d(mask_pred, affine_pred_to_g_t[np.newaxis, :, :])

                mask_pred_np[:, :, 0] -= mask_pred_np[:, :, 0].min()
                mask_pred_np[:, :, 0] /= (mask_pred_np[:, :, 0].max() + 1e-9)
                mask_pred_np[:, :, 0] *= 2.0
                mask_pred_np[:, :, 1] -= mask_pred_np[:, :, 1].min()
                mask_pred_np[:, :, 1] /= (mask_pred_np[:, :, 1].max() + 1e-9)

                presenter = Presenter()
                presenter.show_image(mask_pred_g_np,
                                     "mask_pred_g",
                                     torch=False,
                                     waitkey=1,
                                     scale=4)
                #import matplotlib.pyplot as plt
                #print("image.data shape:", image.data.cpu().numpy().shape)
                #plt.imshow(image.data.squeeze().permute(1,2,0).cpu().numpy())
                #plt.show()
                # presenter.show_image(image.data, "mask_pred_g", torch=False, waitkey=1, scale=4)
                #import pdb; pdb.set_trace()
                pred_viz_np = presenter.overlaid_image(image.data,
                                                       mask_pred_np,
                                                       channel=0)
                # TODO: Don't show labels
                # TODO: OpenCV colours
                #label_mask_np = p.data.cpu().numpy()[0].transpose(1,2,0)
                labl_viz_np = presenter.overlaid_image(image.data,
                                                       label_mask.data,
                                                       channel=0)
                viz_img_np = np.concatenate((pred_viz_np, labl_viz_np), axis=1)
                viz_img_np = pred_viz_np

                viz_img = presenter.overlay_text(viz_img_np,
                                                 instruction_dbg_str)
                cv2.imshow("interactive viz", viz_img)
                cv2.waitKey(100)

                rollout_model(exec_model, env, env_ids[0][s], set_idxs[0][s],
                              seg_idxs[0][s], tok_instruction)
                write_instruction("")
env = PomdpInterface()

count = 0

presenter = Presenter()


def save_landmark_img(state, landmark_name, i, eval):
    data_dir = get_landmark_images_dir(landmark_name, eval)
    os.makedirs(data_dir, exist_ok=True)
    full_path = os.path.join(data_dir, landmark_name + "_" + str(i) + ".jpg")
    scipy.misc.imsave(full_path, state.image)


for landmark_name, landmark_radius in LANDMARK_RADII.items():

    for i in range(IMAGES_PER_LANDMARK_TRAIN):
        print("Generating train image " + str(i) + " for landmark '" +
              landmark_name + "'")
        state = env.reset_to_random_cv_env(landmark_name)
        presenter.show_image(state.image)
        save_landmark_img(state, landmark_name, i, False)

    for i in range(IMAGES_PER_LANDMARK_TEST):
        print("Generating test image " + str(i) + " for landmark '" +
              landmark_name + "'")
        state = env.reset_to_random_cv_env(landmark_name)
        presenter.show_image(state.image)
        save_landmark_img(state, landmark_name, i, True)
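The script above writes images with scipy.misc.imsave, which was deprecated in SciPy 1.0 and removed in SciPy 1.2. On a current SciPy install, a replacement based on imageio could look like this sketch; the data_dir default is a placeholder, since the original script resolves the directory with get_landmark_images_dir(landmark_name, eval).

import os

import imageio
import numpy as np


def save_landmark_img_imageio(state, landmark_name, i, eval, data_dir="landmark_images"):
    # Same intent as save_landmark_img above, but using imageio instead of the removed
    # scipy.misc.imsave. data_dir stands in for get_landmark_images_dir(landmark_name, eval).
    os.makedirs(data_dir, exist_ok=True)
    full_path = os.path.join(data_dir, f"{landmark_name}_{i}.jpg")
    image = np.asarray(state.image)
    if image.dtype != np.uint8:
        # JPEG expects uint8; scale float images assumed to be in [0, 1] up to [0, 255].
        image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    imageio.imwrite(full_path, image)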
Example #5
    def show_landmark_locations(self, loop=True, states=None):
        # Show landmark locations in first-person images
        img_all = self.tensor_store.get("images")
        img_w_all = self.tensor_store.get("images_w")
        import rollout.run_metadata as md
        if md.IS_ROLLOUT:
            # TODO: Discard this and move this to PomdpInterface or something
            # (it's got nothing to do with the model)
            # load landmark positions from configs
            from data_io.env import load_env_config
            from learning.datasets.aux_data_providers import get_landmark_locations_airsim
            from learning.models.semantic_map.pinhole_camera_inv import PinholeCameraProjection
            projector = PinholeCameraProjection(
                map_size_px=self.params["global_map_size"],
                world_size_px=self.params["world_size_px"],
                world_size_m=self.params["world_size_m"],
                img_x=self.params["img_w"],
                img_y=self.params["img_h"],
                cam_fov=self.params["cam_h_fov"],
                #TODO: Handle correctly
                domain="sim",
                use_depth=False)
            conf_json = load_env_config(md.ENV_ID)
            landmark_names, landmark_indices, landmark_pos = get_landmark_locations_airsim(
                conf_json)
            cam_poses = self.cam_poses_from_states(states)
            cam_pos = cam_poses.position[0]
            cam_rot = cam_poses.orientation[0]
            lm_pos_map_all = []
            lm_pos_img_all = []
            for i, landmark_in_world in enumerate(landmark_pos):
                lm_pos_img, landmark_in_cam, status = projector.world_point_to_image(
                    cam_pos, cam_rot, landmark_in_world)
                lm_pos_map = torch.from_numpy(
                    transformations.pos_m_to_px(
                        landmark_in_world[np.newaxis, :],
                        self.params["global_map_size"],
                        self.params["world_size_m"],
                        self.params["world_size_px"]))
                lm_pos_map_all += [lm_pos_map[0]]
                if lm_pos_img is not None:
                    lm_pos_img_all += [lm_pos_img]

            lm_pos_img_all = [lm_pos_img_all]
            lm_pos_map_all = [lm_pos_map_all]

        else:
            lm_pos_img_all = self.tensor_store.get("lm_pos_fpv_img")
            lm_pos_map_all = self.tensor_store.get("lm_pos_map")

        print("Plotting landmark points")

        for i in range(len(img_all)):
            p = Presenter()
            overlay_fpv = p.overlay_pts_on_image(img_all[i][0],
                                                 lm_pos_img_all[i])
            overlay_map = p.overlay_pts_on_image(img_w_all[i][0],
                                                 lm_pos_map_all[i])
            p.show_image(overlay_fpv, "landmarks_on_fpv_img", scale=8)
            p.show_image(overlay_map, "landmarks_on_map", scale=20)

            if not loop:
                break
Example #6
    def sup_loss_on_batch(self, batch, eval=False, viz=False):

        if eval:
            self.eval()
        else:
            self.train()

        images = cuda_var(batch["images"], self.is_cuda, self.cuda_device)
        instructions = cuda_var(batch["instr"], self.is_cuda, self.cuda_device)
        instruction_masks = cuda_var(batch["instr_mask"], self.is_cuda,
                                     self.cuda_device)
        label_masks = cuda_var(batch["traj_labels"], self.is_cuda,
                               self.cuda_device)

        # Each of the above is a list of lists of tensors, where the outer list is over the batch and the inner
        # list is over the segments. Loop through and accumulate the loss sequentially for each batch element
        # and each segment. Model state (embedding etc.) is reset between batches, but not between segments.
        # The batch is not processed in batch mode because the number of segments varies between examples.

        batch_size = len(images)
        total_class_loss = Variable(empty_float_tensor([1], self.is_cuda,
                                                       self.cuda_device),
                                    requires_grad=True)
        total_ground_loss = Variable(empty_float_tensor([1], self.is_cuda,
                                                        self.cuda_device),
                                     requires_grad=True)
        count = 0

        label_masks = self.label_pool(label_masks)
        mask_pred, features, emb_loss = self(images, instructions,
                                             instruction_masks)

        if BCE:
            mask_pred_flat = mask_pred.view(-1, 1)
            label_masks_flat = label_masks - torch.min(label_masks)
            label_masks_flat = label_masks_flat / (
                torch.max(label_masks_flat) + 1e-9)
            label_masks_flat = label_masks_flat.view(-1, 1).clamp(0, 1)
            main_loss = self.mask_loss(mask_pred_flat, label_masks_flat)

        elif NLL:
            mask_pred_1 = F.softmax(mask_pred, 1, _stacklevel=5)
            mask_pred_2 = 1 - mask_pred_1
            mask_pred_1 = mask_pred_1.unsqueeze(1)
            mask_pred_2 = mask_pred_2.unsqueeze(1)
            mask_pred = torch.cat((mask_pred_1, mask_pred_2), dim=1)
            label_masks = label_masks.clamp(0, 1)
            if self.is_cuda:
                label_masks = label_masks.type(torch.cuda.LongTensor)
            else:
                label_masks = label_masks.type(torch.LongTensor)
            main_loss = self.mask_loss(mask_pred, label_masks)

        elif CE:
            # CrossEntropy2D applies log-softmax to mask_pred internally; the labels are assumed to already
            # be a valid probability distribution, so no softmax is applied to them.
            main_loss = self.mask_loss(mask_pred, label_masks)
            # mask_pred itself is still in logit space, so apply a spatial softmax manually for nicer plotting.
            mask_pred = self.spatialsoftmax(mask_pred)
        else:
            main_loss = self.mask_loss(mask_pred, label_masks)

        # sum emb loss if batch size > 1
        if type(emb_loss) == tuple:
            emb_loss = sum(emb_loss)

        # Extract the feature vector at every landmark's location in the map and apply a linear layer to
        # classify which of the 64 landmarks it is. The landmark positions have to be divided by the same
        # factor as the ResNet downscaling factor (8) before indexing the feature map; a minimal sketch of
        # this kind of 2D gather appears after this example.
        lcount = 0
        for i in range(batch_size):
            if self.class_loss and len(batch["lm_pos"][i]) > 0:
                lcount += 1
                landmark_pos = cuda_var(batch["lm_pos"][i], self.is_cuda,
                                        self.cuda_device)
                landmark_indices = cuda_var(batch["lm_indices"][i],
                                            self.is_cuda, self.cuda_device)
                landmark_coords = (landmark_pos / 8).long()
                lm_features = self.gather2d(features[i:i + 1, 0:32],
                                            landmark_coords)
                lm_pred = self.aux_class_linear(lm_features)
                class_loss = self.aux_loss(lm_pred, landmark_indices)
                total_class_loss = total_class_loss + class_loss

            if self.ground_loss and len(batch["lm_pos"][i]) > 0:
                landmark_pos = cuda_var(batch["lm_pos"][i], self.is_cuda,
                                        self.cuda_device)
                landmark_mentioned = cuda_var(batch["lm_mentioned"][i],
                                              self.is_cuda, self.cuda_device)
                landmark_coords = (landmark_pos / 8).long()
                g_features = self.gather2d(features[i:i + 1, 32:35],
                                           landmark_coords)
                lm_pred = self.aux_ground_linear(g_features)
                ground_loss = self.aux_loss(lm_pred, landmark_mentioned)
                total_ground_loss = total_ground_loss + ground_loss

        total_class_loss = total_class_loss / (lcount + 1e-9)
        total_ground_loss = total_ground_loss / (lcount + 1e-9)
        count += 1

        # Just visualization and debugging code
        if self.get_iter() % 50 == 0:
            presenter = Presenter()
            pred_viz_np = presenter.overlaid_image(images[0].data,
                                                   mask_pred[0].data)
            labl_viz_np = presenter.overlaid_image(images[0].data,
                                                   label_masks[0].data)
            comp = np.concatenate((pred_viz_np, labl_viz_np), axis=1)
            presenter.show_image(comp, "path_pred")

            if hasattr(self.sentence_embedding, "save_att_map"):
                self.sentence_embedding.save_att_map(self.get_iter(), i)

        total_loss = main_loss + 0.1 * total_class_loss + 0.001 * emb_loss + 0.1 * total_ground_loss
        total_loss = total_loss / (count + 1e-9)

        self.write_summaires("eval" if eval else "train", self.get_iter(),
                             total_loss, main_loss, emb_loss, total_class_loss,
                             total_ground_loss)
        self.inc_iter()

        return total_loss
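The batch loop above relies on self.gather2d to pull the feature vector at each downscaled landmark coordinate out of the feature map before the auxiliary linear classifiers. gather2d itself is not shown on this page; a minimal sketch of such a 2D gather, assuming integer coordinates indexing a [1, C, H, W] feature map (the row/column order is an assumption), might be:

import torch


def gather2d_sketch(features, coords):
    # features: [1, C, H, W] feature map (the loop above passes single-element slices
    #           such as features[i:i + 1, 0:32]).
    # coords:   [N, 2] long tensor of landmark map coordinates, already divided by the
    #           network's downscaling factor.
    # Returns a [N, C] matrix with one feature vector per coordinate.
    _, C, H, W = features.shape
    rows = coords[:, 0].clamp(0, H - 1)
    cols = coords[:, 1].clamp(0, W - 1)
    return features[0, :, rows, cols].t()  # [C, N] -> [N, C]

Each row of the returned matrix can then be fed to a classifier head such as aux_class_linear or aux_ground_linear, as in the loop above.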
Example #7
class RolloutVisualizer:
    def __init__(self, resolution=512):
        self.presenter = Presenter()
        self.clear()
        self.current_rollout = {}
        self.current_rollout_name = None
        self.env_image = None
        self.current_timestep = None
        self.world_size_m = P.get_current_parameters()["Setup"]["world_size_m"]

        self.resolution = resolution

    def clear(self):
        self.current_rollout = {}
        self.current_rollout_name = None
        self.env_image = None
        self.current_timestep = None

    def _auto_contrast(self, image):
        import cv2
        image_c = np.clip(image, 0.0, 1.0)
        hsv_image = cv2.cvtColor(image_c, cv2.COLOR_RGB2HSV)
        hsv_image[:, :, 1] *= 1.2
        image_out = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2RGB)
        image_out = np.clip(image_out, 0.0, 1.0)
        #print(image_out.min(), image_out.)
        print(image_out[:, :, 1].min(), image_out[:, :, 1].max())
        # NOTE: the saturation-boosted image_out is only used for the debug print above;
        # the input image is returned unchanged.
        return image

    def _integrate_mask(self, frames):
        frames_out = [frames[0]]
        for frame in frames[1:]:
            new_frame_out = np.maximum(frames_out[-1], frame)
            frames_out.append(new_frame_out)
        return frames_out

    def _draw_landmarks(self, image, env_id):
        lm_names, lm_idx, lm_pos = get_landmark_locations_airsim(env_id=env_id)
        image = self.presenter.draw_landmarks(image, lm_names, lm_pos,
                                              self.world_size_m)
        return image

    def load_video_clip(self, env_id, seg_idx, rollout, domain, cam_name,
                        rollout_dir):
        video_path = os.path.join(
            rollout_dir, f"rollout_{cam_name}_{env_id}-0-{seg_idx}.mkv")
        try:
            #if os.path.getsize(video_path) > 1024 * 1024 * 30
            print("Loading video: ", video_path)
            clip = mpy.VideoFileClip(video_path)
        except Exception as e:
            return None
        return clip

    def grab_frames(self,
                    env_id,
                    seg_idx,
                    rollout,
                    domain,
                    frame_name,
                    scale=1):
        frames = []
        for sample in rollout:
            if frame_name == "image":
                frame = sample["state"].image
            elif frame_name == "action":
                action = sample["action"]
                bg = np.zeros((400, 400, 3))
                frame = self.presenter.draw_action(bg,
                                                   offset=(0, 0),
                                                   action=action)
            elif frame_name == "v_dist_r_inner":
                frame_t = sample[frame_name][:3, :, :].transpose((1, 2, 0))
                # TODO: These should come from params
                map_size = 64
                crop_size = 16
                gap = int((map_size - crop_size) / 2)
                crop_l = gap
                crop_r = map_size - gap
                frame_t = frame_t[crop_l:crop_r, crop_l:crop_r, :]
                frame_t[:, :,
                        0] /= (np.percentile(frame_t[:, :, 0], 99) + 1e-9)
                frame_t[:, :,
                        1] /= (np.percentile(frame_t[:, :, 1], 99) + 1e-9)
                frame_t = np.clip(frame_t, 0.0, 1.0)
                shp = list(frame_t.shape)
                shp[2] = 3
                frame = np.zeros(shp)
                frame[:, :, :2] = frame_t
                frame = cv2.resize(frame,
                                   dsize=(self.resolution, self.resolution))
            elif frame_name == "map_struct":
                frame_t = sample[frame_name][:3, :, :].transpose((1, 2, 0))
                shp = list(frame_t.shape)
                shp[2] = 3
                frame = np.zeros(shp)
                frame[:, :, :2] = frame_t
                frame = cv2.resize(frame,
                                   dsize=(self.resolution, self.resolution))
            elif frame_name == "ego_obs_mask":
                frame_t = sample["map_struct"][:3, :, :].transpose((1, 2, 0))
                shp = list(frame_t.shape)
                shp[2] = 3
                canvas = np.zeros(shp)
                canvas[:, :, :] = 1 - frame_t[:, :, 0:1]
                canvas[:, :, :] -= frame_t[:, :, 1:2]
                canvas = np.clip(canvas, 0.0, 1.0)
                frame = cv2.resize(canvas,
                                   dsize=(self.resolution, self.resolution))
            else:
                frame = sample[frame_name][0, :3, :, :].transpose((1, 2, 0))
            if frame_name in ["image", "v_dist_r_inner"]:
                frame -= frame.min()
                frame = frame / (frame.max() + 1e-9)
            else:
                frame -= np.percentile(frame, 0)
                frame /= (np.percentile(frame, 95) + 1e-9)
                frame = np.clip(frame, 0.0, 1.0)
            if scale != 1:
                frame = self.presenter.scale_image(frame, scale)
            frames.append(frame)
        return frames

    def action_visualization(self,
                             env_id,
                             seg_idx,
                             rollout,
                             domain,
                             frame_name="action"):
        frames = []
        for sample in rollout:
            action = sample[frame_name]
            frame = np.ones((200, 200, 3), dtype=np.uint8)
            self.presenter.draw_action(frame, (1, 159), action)
            frames.append(frame)
        return frames

    def overlay_frames(self, under_frames, over_frames, strength=0.5):
        overlaid_frames = [
            self.presenter.overlaid_image(u, o, strength=strength)
            for u, o in zip(under_frames, over_frames)
        ]
        return overlaid_frames

    def top_down_visualization(self, env_id, seg_idx, rollout, domain, params):
        fd = domain == "real"
        obl = domain in ["simulator", "sim"]
        print(domain, obl)
        if params["draw_topdown"]:
            bg_image = load_env_img(
                env_id,
                self.resolution,
                self.resolution,
                real_drone=True if domain == "real" else False,
                origin_bottom_left=obl,
                flipdiag=False,
                alpha=True)
        else:
            bg_image = np.zeros((self.resolution, self.resolution, 3))
        if params["draw_landmarks"]:
            bg_image = self._draw_landmarks(bg_image, env_id)

        # Initialize stuff
        frames = []
        poses_m = []
        poses_px = []
        for sample in rollout:
            sample_image = bg_image.copy()
            frames.append(sample_image)
            state = sample["state"]
            pose_m = state.get_drone_pose()
            pose_px = poses_m_to_px(pose_m,
                                    self.resolution,
                                    self.resolution,
                                    self.world_size_m,
                                    batch_dim=False)
            poses_px.append(pose_px)
            poses_m.append(pose_m)

        instruction = rollout[0]["instruction"]
        print("Instruction: ")
        print(instruction)

        # Draw visitation distributions if requested:
        if params["include_vdist"]:
            print("Drawing visitation distributions")
            if params["ego_vdist"]:
                inner_key = "v_dist_r_inner"
                outer_key = "v_dist_r_outer"
            else:
                inner_key = "v_dist_w_inner"
                outer_key = "v_dist_w_outer"
            for i, sample in enumerate(rollout):
                v_dist_w_inner = np.flipud(sample[inner_key].transpose(
                    (2, 1, 0)))
                # Expand range of each channel separately so that stop entropy doesn't affect how trajectory looks
                v_dist_w_inner[:, :, 0] /= (
                    np.percentile(v_dist_w_inner[:, :, 0], 99.5) + 1e-9)
                v_dist_w_inner[:, :, 1] /= (
                    np.percentile(v_dist_w_inner[:, :, 1], 99.5) + 1e-9)
                v_dist_w_inner = np.clip(v_dist_w_inner, 0.0, 1.0)
                v_dist_w_outer = sample[outer_key]
                if bg_image.max() - bg_image.min() > 1e-9:
                    f = self.presenter.blend_image(frames[i], v_dist_w_inner)
                else:
                    f = self.presenter.overlaid_image(frames[i],
                                                      v_dist_w_inner,
                                                      strength=1.0)
                f = self.presenter.draw_prob_bars(f, v_dist_w_outer)
                frames[i] = f

        if params["include_layer"]:
            layer_name = params["include_layer"]
            print(f"Drawing first 3 channels of layer {layer_name}")
            accumulate = False
            invert = False
            gray = False
            if layer_name == "M_W_accum":
                accumulate = True
                layer_name = "M_W"
            if layer_name == "M_W_accum_inv":
                invert = True
                accumulate = True
                layer_name = "M_W"

            if layer_name.endswith("_Gray"):
                gray = True
                layer_name = layer_name[:-len("_Gray")]

            for i, sample in enumerate(rollout):
                layer = sample[layer_name]
                if len(layer.shape) == 4:
                    layer = layer[0, :, :, :]
                layer = layer.transpose((2, 1, 0))
                layer = np.flipud(layer)
                if layer_name in ["S_W", "F_W"]:
                    layer = layer[:, :, :3]
                else:
                    layer = layer[:, :, :3]
                if layer_name in ["S_W", "R_W", "F_W"]:
                    if gray:
                        layer -= np.percentile(layer, 1)
                        layer /= (np.percentile(layer, 99) + 1e-9)
                    else:
                        layer /= (np.percentile(layer, 97) + 1e-9)
                    layer = np.clip(layer, 0.0, 1.0)

                if layer_name in ["M_W"]:
                    # A hard 0-1 mask does not encode properly with the video codec; a bit of imperceptible
                    # gaussian noise can fix that, though this snippet only converts the mask to float and
                    # tiles it to 3 channels (a sketch of the noise trick appears after this example).
                    layer = layer.astype(np.float32)
                    layer = np.tile(layer, (1, 1, 3))

                if accumulate and i > 0:
                    layer = np.maximum(layer, prev_layer)

                prev_layer = layer
                if invert:
                    layer = 1 - layer
                if frames[i].max() > 0.01:
                    frames[i] = self.presenter.blend_image(
                        frames[i], layer[:, :, :3])
                    #frames[i] = self.presenter.overlaid_image(frames[i], layer[:, :, :3])
                else:
                    scale = (int(self.resolution / layer.shape[0]),
                             int(self.resolution / layer.shape[1]))
                    frames[i] = self.presenter.prep_image(layer[:, :, :3],
                                                          scale=scale)

        if params["include_instr"]:
            print("Drawing instruction")
            for i, sample in enumerate(rollout):
                frames[i] = self.presenter.overlay_text(
                    frames[i], sample["instruction"])

        # Draw trajectory history
        if params["draw_trajectory"]:
            print("Drawing trajectory")
            for i, sample in enumerate(rollout):
                history = poses_px[:i + 1]
                position_history = [h.position for h in history]
                frames[i] = self.presenter.draw_trajectory(
                    frames[i], position_history, self.world_size_m)

        # Draw drone
        if params["draw_drone"]:
            print("Drawing drone")
            for i, sample in enumerate(rollout):
                frames[i] = self.presenter.draw_drone(frames[i], poses_m[i],
                                                      self.world_size_m)

        # Draw observability mask:
        if params["draw_fov"]:
            print("Drawing FOV")
            for i, sample in enumerate(rollout):
                frames[i] = self.presenter.draw_observability(
                    frames[i], poses_m[i], self.world_size_m, 84)

        # Visualize
        if False:
            for i, sample in enumerate(rollout):
                self.presenter.show_image(frames[i],
                                          "sample_image",
                                          scale=1,
                                          waitkey=True)

        return frames

    def start_rollout(self,
                      env_id,
                      set_idx,
                      seg_idx,
                      domain,
                      dataset,
                      suffix=""):
        rollout_name = f"{env_id}:{set_idx}:{seg_idx}:{domain}:{dataset}:{suffix}"
        self.current_rollout = {"top-down": []}
        self.current_rollout_name = rollout_name
        self.env_image = load_env_img(env_id, 512, 512, alpha=True)

    def start_timestep(self, timestep):
        self.current_timestep = timestep
        # Add top-down view image for the new timestep
        self.current_rollout["top-down"].append()
        self.current_rollout["top-down"].append(self.env_image.copy())

    def set_drone_state(self, timestep, state):
        drone_pose = state.get_cam_pose()

        # Draw drone sprite on top_down image
        tdimg = self.current_rollout["top-down"][timestep]
        tdimg_n = self.presenter.draw_drone(
            tdimg, drone_pose,
            P.get_current_parameters()["Setup"]["world_size_m"])

        self.current_rollout["top-down"][timestep] = tdimg_n