    def get_predefined_inits(self, init_configs):
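        # Load a fixed list of (scene, image, target category) start configurations and
        # cache per-start info (scene annotations, scale, target label, shortest-path
        # length to the goal) keyed by a running start id.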
        confs = np.load(init_configs, allow_pickle=True)  # [(scene, im_name, cat), (...), ...]
        start_id = 0
        images, scene_annotations, scene_name, scene_scale, targets, path_length = {}, {}, {}, {}, {}, {}
        for conf in confs:
            (scene, im_name, cat) = conf
            annotations, scale, _, _, _ = dh.load_scene_info(
                self.datasetPath, scene)
            scene_annotations[scene] = annotations
            target_lbl = self.cat_dict[cat]
            graph = self.graphs_dict[target_lbl][scene]
            path = nx.shortest_path(graph, im_name, "goal")

            images[start_id] = im_name
            scene_name[start_id] = scene
            scene_scale[start_id] = scale
            targets[start_id] = target_lbl
            path_length[start_id] = len(path)
            start_id += 1

        self.scene_annotations = scene_annotations
        self.images = images
        self.scene_name = scene_name
        self.scene_scale = scene_scale
        self.targets = targets
        self.path_length = path_length
    def sample_episodes_random(self):
        # This function chooses the actions randomly and not through shortest path
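        # Each scene contributes self.n_episodes episodes of length self.seq_len.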
        epi_id = 0  # episode id
        im_paths, pose, scene_name, scene_scale = {}, {}, {}, {}
        for scene in self.scene_list:
            annotations, scale, im_names_all, world_poses, directions = dh.load_scene_info(
                self.datasetPath, scene)
            # Create the graph of the environment
            graph = dh.create_scene_graph(annotations,
                                          im_names=im_names_all,
                                          action_set=self.actions)
            scene_epi_count = 0
            while scene_epi_count < self.n_episodes:
                # Randomly select an image index as the starting position
                idx = np.random.randint(len(im_names_all), size=1)
                im_name_0 = im_names_all[idx[0]]
                # organize the episodes into dictionaries holding different information
                poses_epi, path = [], []
                for i in range(self.seq_len):
                    if i == 0:
                        current_im = im_name_0
                    else:
                        # randomly choose the action
                        sel_action = self.actions[
                            np.random.randint(len(self.actions), size=1)[0]]
                        next_im = annotations[current_im][sel_action]
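                        # an empty annotation entry means the move is blocked (collision); stay at the current image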
                        if not (next_im == ''):
                            current_im = next_im
                    path.append(current_im)
                    im_idx = np.where(im_names_all == current_im)[0]
                    pos_tmp = world_poses[im_idx][0] * scale  # 3 x 1
                    pose_x_gt = pos_tmp[0, :]
                    pose_z_gt = pos_tmp[2, :]
                    dir_tmp = directions[im_idx][0]  # 3 x 1
                    dir_gt = np.arctan2(dir_tmp[2, :], dir_tmp[0, :])[0]  # [-pi, pi]
                    poses_epi.append([pose_x_gt, pose_z_gt, dir_gt])

                im_paths[epi_id] = np.asarray(path)
                pose[epi_id] = np.asarray(poses_epi, dtype=np.float32)
                scene_name[epi_id] = scene
                scene_scale[epi_id] = scale
                epi_id += 1
                scene_epi_count += 1

        self.im_paths = im_paths
        self.pose = pose
        self.scene_name = scene_name
        self.scene_scale = scene_scale
def visualize_nav(avd_root, episode_results, scene, save_path):
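    # Plot each navigation episode on a top-down view of the scene point cloud:
    # goal image poses are shown in green and the agent's per-step poses/images are
    # added with plot_step_nav, with one output folder per episode.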
    # Load the targets info
    targets_file_path = avd_root + "Meta/annotated_targets.npy"
    targets_data = np.load(targets_file_path,
                           encoding='bytes',
                           allow_pickle=True).item()
    # Load scene info
    annotations, scale, im_names_all, world_poses, directions = dh.load_scene_info(
        avd_root, scene)
    # Load the pcloud
    scene_path = avd_root + scene + "/"  # to read the images
    pcloud, color_cloud = get_pcloud(scene_path=scene_path, scale=scale)
    for i in range(len(episode_results)):
        (image_seq, action_seq, cat, done) = episode_results[i]
        image_seq = np.asarray(image_seq)
        #print("image_seq:", image_seq)
        im_list = get_images(imgs_name=image_seq, scene_path=scene_path)
        # get poses of the sequence images
        poses_im = dh.get_image_poses(world_poses, directions, im_names_all,
                                      image_seq, scale)
        # get target images and their poses
        goal_ims = [
            x.decode("utf-8") + ".jpg"
            for x in targets_data[cat.encode()][scene.encode()]
        ]
        if goal_ims[-1] == ".jpg":  # drop a trailing empty entry that decodes to just ".jpg"
            goal_ims = goal_ims[:-1]
        #print("goal_ims:", goal_ims)
        poses_goal = dh.get_image_poses(world_poses, directions, im_names_all,
                                        np.asarray(goal_ims), scale)

        save_dir = save_path + str(i) + "/"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        fig, ax = plt.subplots(1, 2)
        # visualize the point cloud and the target poses
        ax[0].scatter(pcloud[:, 0], pcloud[:, 2], c=color_cloud / 255.0, s=2)
        ax[0].scatter(poses_goal[:, 0], poses_goal[:, 1], color="green", s=3)
        for j in range(im_list.shape[0]):  # add each step of the episode to the plot
            ax[1].set_title(cat)
            plot_step_nav(ax=ax,
                          k=j,
                          im_list=im_list,
                          poses=poses_im,
                          save_dir=save_dir)
        plt.close(fig)  # release the figure before the next episode
    def sample_episodes(self):
        # This function chooses the actions through shortest path
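        # Each episode follows the shortest path between two randomly chosen reachable
        # images; it is kept only if the path has at least self.seq_len nodes, then
        # truncated to self.seq_len.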
        epi_id = 0  # episode id
        im_paths, pose, scene_name, scene_scale = {}, {}, {}, {}
        for scene in self.scene_list:
            annotations, scale, im_names_all, world_poses, directions = dh.load_scene_info(
                self.datasetPath, scene)
            # Create the graph of the environment
            graph = dh.create_scene_graph(annotations,
                                          im_names=im_names_all,
                                          action_set=self.actions)
            scene_epi_count = 0
            while scene_epi_count < self.n_episodes:
                # Randomly select two nodes and sample a trajectory across their shortest path
                idx = np.random.randint(len(im_names_all), size=2)
                im_name_0 = im_names_all[idx[0]]
                im_name_1 = im_names_all[idx[1]]
                # organize the episodes into dictionaries holding different information
                if nx.has_path(graph, im_name_0, im_name_1):
                    path = nx.shortest_path(
                        graph, im_name_0,
                        im_name_1)  # sequence of nodes leading to goal
                    if len(path) >= self.seq_len:
                        poses_epi = []
                        for i in range(self.seq_len):
                            next_im = path[i]
                            im_idx = np.where(im_names_all == next_im)[0]
                            pos_tmp = world_poses[im_idx][0] * scale  # 3 x 1
                            pose_x_gt = pos_tmp[0, :]
                            pose_z_gt = pos_tmp[2, :]
                            dir_tmp = directions[im_idx][0]  # 3 x 1
                            # [-pi, pi]; assumes that the 0 direction is to the right
                            dir_gt = np.arctan2(dir_tmp[2, :], dir_tmp[0, :])[0]
                            poses_epi.append([pose_x_gt, pose_z_gt, dir_gt])

                        im_paths[epi_id] = np.asarray(path[:self.seq_len])
                        pose[epi_id] = np.asarray(poses_epi, dtype=np.float32)
                        scene_name[epi_id] = scene
                        scene_scale[epi_id] = scale
                        epi_id += 1
                        scene_epi_count += 1

        self.im_paths = im_paths
        self.pose = pose
        self.scene_name = scene_name
        self.scene_scale = scene_scale
    def sample_starting_points(self):
        # Store all info necessary for a starting position
        start_id = 0
        images, scene_annotations, scene_name, scene_scale, targets, pose, path_length = {}, {}, {}, {}, {}, {}, {}
        for scene in self.scene_list:
            annotations, scale, im_names_all, _, _ = dh.load_scene_info(
                self.datasetPath, scene)
            scene_start_count = 0
            while scene_start_count < self.n_start_pos:
                # Randomly select an image index as the starting position
                idx = np.random.randint(len(im_names_all), size=1)
                im_name_0 = im_names_all[idx[0]]
                # Randomly select a target that exists in that scene
                candidates = dh.candidate_targets(
                    scene, self.cat_dict,
                    self.targets_data)  # Get the list of possible targets
                idx_cat = np.random.randint(len(candidates), size=1)
                cat = candidates[idx_cat[0]]
                target_lbl = self.cat_dict[cat]
                graph = self.graphs_dict[target_lbl][scene]
                path = nx.shortest_path(graph, im_name_0, "goal")
                if len(path) - 2 == 0:  # this means that im_name_0 is a goal location
                    continue
                if len(path) - 1 > self.max_shortest_path:  # limit on the length of episodes
                    continue
                # Add the starting location in the pool
                images[start_id] = im_name_0
                scene_name[start_id] = scene
                scene_scale[start_id] = scale
                targets[start_id] = target_lbl
                path_length[start_id] = len(path)
                scene_start_count += 1
                start_id += 1
            scene_annotations[scene] = annotations
        self.scene_annotations = scene_annotations
        self.images = images
        self.scene_name = scene_name
        self.scene_scale = scene_scale
        self.targets = targets
        self.path_length = path_length
def visualize_loc(avd_root, episode_results, scene, save_path):
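    # Plot localization episodes over the scene point cloud: ground-truth and predicted
    # trajectories are reconstructed from relative poses anchored at the same initial
    # absolute pose and drawn step by step with plot_step_loc.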
    _, scale, _, _, _ = dh.load_scene_info(avd_root, scene)
    # Load the pcloud
    scene_path = avd_root + scene + "/"  # to read the images
    pcloud, color_cloud = get_pcloud(scene_path=scene_path, scale=scale)
    for i in range(len(episode_results)):
        (image_seq, rel_pose, abs_pose, pred_rel_pose, _scene,
         _) = episode_results[i]  # avoid shadowing the scene argument
        # in case anyone wants to show the images as well...
        image_seq = np.asarray(image_seq)
        im_list = get_images(imgs_name=image_seq, scene_path=scene_path)

        init_pose = abs_pose[0, :]
        poses_gt = dh.absolute_poses(rel_pose=rel_pose, origin=init_pose)
        poses_pred = dh.absolute_poses(rel_pose=pred_rel_pose,
                                       origin=init_pose)
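        # both trajectories share the same origin, so they are directly comparable on the map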

        save_dir = save_path + str(i) + "/"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        fig, ax = plt.subplots(1, 1)
        # visualize the point cloud (with color)
        ax.scatter(pcloud[:, 0], pcloud[:, 2], c=color_cloud / 255.0, s=7)
        # show the init pose
        plot_step_loc(ax=ax,
                      k=0,
                      poses_pred=poses_pred,
                      poses_gt=poses_gt,
                      save_dir=save_dir)
        for j in range(1, im_list.shape[0]):  # add the rest of the steps on the plot
            plot_step_loc(ax=ax,
                          k=j,
                          poses_pred=poses_pred,
                          poses_gt=poses_gt,
                          save_dir=save_dir)
        plt.close(fig)  # release the figure before the next episode
def evaluate_NavNet(parIL, parMapNet, mapNet, ego_encoder, test_iter, test_ids, test_data, action_list):
    print("\nRunning validation on NavNet!")
    with torch.no_grad():
        policy_net = hl.load_model(model_dir=parIL.model_dir, model_name="ILNet", test_iter=test_iter)
        acc, epi_length, path_ratio = 0, 0, 0
        episode_results, episode_count = {}, 0 # store predictions
        for i in test_ids:
            test_ex = test_data[i]
            # Get all info for the starting position
            mapNet_input_start = prepare_mapNet_input(ex=test_ex)
            target_lbl = test_ex["target_lbl"]
            im_obsv = test_ex['image_obsv'].cuda()
            dets_obsv = test_ex['dets_obsv'].cuda()
            tvec = torch.zeros(1, parIL.nTargets).float().cuda()
            tvec[0,target_lbl] = 1
            # We need to keep other info to allow us to do the steps later
            image_name, scene, scale = [], [], []
            image_name.append(test_ex['image_name'])
            scene.append(test_ex['scene'])
            scale.append(test_ex['scale'])
            shortest_path_length = test_ex['path_length']

            if parIL.use_p_gt:
                # get the ground-truth pose, which is the relative pose with respect to the first image
                _, _, im_names_all, world_poses, directions = dh.load_scene_info(parIL.avd_root, scene[0])
                start_abs_pose = dh.get_image_poses(world_poses, directions, im_names_all, image_name, scale[0]) # init pose of the episode # 1 x 3

            # Get state from mapNet
            p_, map_ = mapNet.forward_single_step(local_info=mapNet_input_start, t=0, 
                                                    input_flags=parMapNet.input_flags, update_type=parMapNet.update_type)
            collision_ = torch.tensor([0], dtype=torch.float32).cuda() # collision indicator is 0
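            # policy state: (map, pose prediction, one-hot target, collision flag[, egocentric observation features])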
            if parIL.use_ego_obsv:
                enc_in = torch.cat((im_obsv, dets_obsv), 0).unsqueeze(0)
                ego_obsv_feat = ego_encoder(enc_in) # 1 x 512 x 1 x 1
                state = (map_, p_, tvec, collision_, ego_obsv_feat)
            else:
                state = (map_, p_, tvec, collision_) 
            current_im = image_name[0]

            done=0
            image_seq, action_seq = [], []
            image_seq.append(current_im)
            policy_net.hidden = policy_net.init_hidden(batch_size=1, state_items=len(state)-1)
            for t in range(1, parIL.max_steps+1):
                pred_costs = policy_net(state, parIL.use_ego_obsv) # apply policy for single step
                pred_costs = pred_costs.view(-1).cpu().numpy()
                # choose the action with a certain prob
                pred_probs = softmax(-pred_costs)
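                # lower predicted cost -> higher selection probability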
                pred_label = np.random.choice(len(action_list), 1, p=pred_probs)[0]
                pred_action = action_list[pred_label]

                # get the next image, check collision and goal
                next_im = test_data.scene_annotations[scene[0]][current_im][pred_action]
                if next_im=='':
                    image_seq.append(current_im)
                else:
                    image_seq.append(next_im)
                action_seq.append(pred_action)
                print(t, current_im, pred_action, next_im)
                if not(next_im==''): # not collision case
                    collision = 0
                    # check for goal
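                    # len(path) counts both endpoints, so -2 gives the number of moves left to reach a goal image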
                    path_dist = len(nx.shortest_path(test_data.graphs_dict[target_lbl][scene[0]], next_im, "goal")) - 2
                    if path_dist <= parIL.steps_from_goal: # GOAL!
                        acc += 1
                        epi_length += t
                        path_ratio += t/float(shortest_path_length) # ratio of executed path length to the shortest-path length
                        done=1
                        break
                    # get next state from mapNet
                    batch_next, obsv_batch_next = test_data.get_step_data(next_ims=[next_im], scenes=scene, scales=scale)
                    if parIL.use_p_gt:
                        next_im_abs_pose = dh.get_image_poses(world_poses, directions, im_names_all, [next_im], scale[0])
                        abs_poses = np.concatenate((start_abs_pose, next_im_abs_pose), axis=0)
                        rel_poses = dh.relative_poses(poses=abs_poses)
                        next_im_rel_pose = np.expand_dims(rel_poses[1,:], axis=0)
                        p_gt = dh.build_p_gt(parMapNet, pose_gt_batch=np.expand_dims(next_im_rel_pose, axis=1)).squeeze(1)
                        p_next, map_next = mapNet.forward_single_step(local_info=batch_next, t=t, input_flags=parMapNet.input_flags,
                                                                map_previous=state[0], p_given=p_gt, update_type=parMapNet.update_type)
                    else:
                        p_next, map_next = mapNet.forward_single_step(local_info=batch_next, t=t, 
                                            input_flags=parMapNet.input_flags, map_previous=state[0], update_type=parMapNet.update_type)
                    if parIL.use_ego_obsv:
                        enc_in = torch.cat(obsv_batch_next, 1)
                        ego_obsv_feat = ego_encoder(enc_in) # b x 512 x 1 x 1
                        state = (map_next, p_next, tvec, torch.tensor([collision], dtype=torch.float32).cuda(), ego_obsv_feat)
                    else:
                        state = (map_next, p_next, tvec, torch.tensor([collision], dtype=torch.float32).cuda())
                    current_im = next_im

                else: # collision case
                    collision = 1
                    if parIL.stop_on_collision:
                        break
                    if parIL.use_ego_obsv:
                        state = (state[0], state[1], state[2], torch.tensor([collision], dtype=torch.float32).cuda(), state[4])
                    else:
                        state = (state[0], state[1], state[2], torch.tensor([collision], dtype=torch.float32).cuda())
                
            episode_results[episode_count] = (image_seq, action_seq, parIL.lbl_to_cat[target_lbl], done)
            episode_count+=1
        # store the episodes
        episode_results_path = parIL.model_dir+'episode_results_eval_'+str(test_iter)+'.pkl'
        with open(episode_results_path, 'wb') as f:
            pickle.dump(episode_results, f)
        
        success_rate = acc / float(len(test_ids))
        if acc > 0:
            mean_epi_length = epi_length / float(acc)
            avg_path_length_ratio = path_ratio / float(acc)
        else:
            mean_epi_length = 0
            avg_path_length_ratio = 0
        print("Test iter:", test_iter, "Success rate:", success_rate)
        print("Mean epi length:", mean_epi_length, "Avg path length ratio:", avg_path_length_ratio)
def unroll_policy(parIL, parMapNet, policy_net, mapNet, action_list,
                  batch_size, seq_len, graphs, ego_encoder=None):
    # Unroll the learned policy to collect online training data
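    # DAgger-style data collection: the current policy picks the actions, while the
    # ground-truth (shortest-path) cost of every action at each visited state is stored
    # as supervision. ego_encoder is only needed when parIL.use_ego_obsv is set.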
    with torch.no_grad():
        nScenes = 4  # how many scenes to use for this minibatch
        ind = np.random.randint(len(parIL.train_scene_list), size=nScenes)
        scene_list = np.asarray(parIL.train_scene_list)
        sel_scene = scene_list[ind]
        avd_dagger = AVD_online(par=parIL,
                                nStartPos=batch_size // nScenes,  # integer division; assumes batch_size is a multiple of nScenes
                                scene_list=sel_scene,
                                action_list=action_list,
                                graphs_dict=graphs)
        ########### initialize all the arrays to be returned
        imgs_batch = torch.zeros(batch_size, seq_len, 3,
                                 avd_dagger.cropSize[1],
                                 avd_dagger.cropSize[0]).float().cuda()
        sseg_batch = torch.zeros(batch_size, seq_len, 1,
                                 avd_dagger.cropSize[1],
                                 avd_dagger.cropSize[0]).float().cuda()
        dets_batch = torch.zeros(batch_size, seq_len, avd_dagger.dets_nClasses,
                                 avd_dagger.cropSize[1],
                                 avd_dagger.cropSize[0]).float().cuda()
        imgs_obsv_batch = torch.zeros(
            batch_size, seq_len, 3, avd_dagger.cropSizeObsv[1],
            avd_dagger.cropSizeObsv[0]).float().cuda()
        dets_obsv_batch = torch.zeros(
            batch_size, seq_len, 1, avd_dagger.cropSizeObsv[1],
            avd_dagger.cropSizeObsv[0]).float().cuda()
        tvec_batch = torch.zeros(batch_size, parIL.nTargets).float().cuda()
        pose_gt_batch = np.zeros((batch_size, seq_len, 3), dtype=np.float32)
        collisions_batch = torch.zeros(batch_size, seq_len).float().cuda()
        costs_batch = torch.zeros(batch_size, seq_len,
                                  len(action_list)).float().cuda()
        points2D_batch, local3D_batch = [], []
        image_names_batch, scene_batch, scale_batch, actions = [], [], [], []
        #########################################
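        # Roll out the policy from every sampled starting position and fill the batch tensors step by step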
        for i in range(len(avd_dagger)):
            ex = avd_dagger[i]
            img = ex["image"].unsqueeze(0)
            points2D_seq, local3D_seq = [], []
            points2D_seq.append(ex["points2D"])
            local3D_seq.append(ex["local3D"])
            sseg = ex["sseg"].unsqueeze(0)
            dets = ex['dets'].unsqueeze(0)
            mapNet_input_start = (img.cuda(), points2D_seq, local3D_seq,
                                  sseg.cuda(), dets.cuda())
            # get all other info needed for the episode
            target_lbl = ex["target_lbl"]
            im_obsv = ex['image_obsv'].cuda()
            dets_obsv = ex['dets_obsv'].cuda()
            tvec = torch.zeros(1, parIL.nTargets).float().cuda()
            tvec[0, target_lbl] = 1
            image_name_seq = []
            image_name_seq.append(ex['image_name'])
            scene = ex['scene']
            scene_batch.append(scene)
            scale = ex['scale']
            scale_batch.append(scale)
            graph = avd_dagger.graphs_dict[target_lbl][scene]
            abs_pose_seq = np.zeros((seq_len, 3), dtype=np.float32)

            annotations, _, im_names_all, world_poses, directions = dh.load_scene_info(
                parIL.avd_root, scene)
            start_abs_pose = dh.get_image_poses(
                world_poses, directions, im_names_all, image_name_seq,
                scale)  # init pose of the episode # 1 x 3
            # Get state from mapNet
            p_, map_ = mapNet.forward_single_step(
                local_info=mapNet_input_start,
                t=0,
                input_flags=parMapNet.input_flags,
                update_type=parMapNet.update_type)
            collision_ = torch.tensor(
                [0], dtype=torch.float32).cuda()  # collision indicator is 0
            if parIL.use_ego_obsv:
                enc_in = torch.cat((im_obsv, dets_obsv), 0).unsqueeze(0)
                ego_obsv_feat = ego_encoder(enc_in)  # 1 x 512 x 1 x 1
                state = (map_, p_, tvec, collision_, ego_obsv_feat)
            else:
                state = (map_, p_, tvec, collision_)
            current_im = image_name_seq[0]  #.copy()

            imgs_batch[i, 0, :, :, :] = img
            sseg_batch[i, 0, :, :, :] = sseg
            dets_batch[i, 0, :, :, :] = dets
            imgs_obsv_batch[i, 0, :, :, :] = im_obsv
            dets_obsv_batch[i, 0, :, :, :] = dets_obsv
            tvec_batch[i] = tvec
            collisions_batch[i, 0] = collision_
            abs_pose_seq[0, :] = start_abs_pose
            cost = np.asarray(dh.get_state_action_cost(current_im, action_list,
                                                       annotations, graph),
                              dtype=np.float32)
            costs_batch[i, 0, :] = torch.from_numpy(cost).float()

            policy_net.hidden = policy_net.init_hidden(batch_size=1,
                                                       state_items=len(state) - 1)
            for t in range(1, seq_len):
                pred_costs = policy_net(
                    state, parIL.use_ego_obsv)  # apply policy for single step
                pred_costs = pred_costs.view(-1).cpu().numpy()
                # choose the action with the lowest predicted cost
                pred_label = np.argmin(pred_costs)
                pred_action = action_list[pred_label]
                actions.append(pred_action)

                # get the next image, check collision and goal
                next_im = avd_dagger.scene_annotations[scene][current_im][
                    pred_action]
                #print(t, current_im, pred_action, next_im)
                if not (next_im == ''):  # not collision case
                    collision = 0
                    # get next state from mapNet
                    batch_next, obsv_batch_next = avd_dagger.get_step_data(
                        next_ims=[next_im], scenes=[scene], scales=[scale])
                    next_im_abs_pose = dh.get_image_poses(
                        world_poses, directions, im_names_all, [next_im],
                        scale)
                    if parIL.use_p_gt:
                        abs_poses = np.concatenate(
                            (start_abs_pose, next_im_abs_pose), axis=0)
                        rel_poses = dh.relative_poses(poses=abs_poses)
                        next_im_rel_pose = np.expand_dims(rel_poses[1, :],
                                                          axis=0)
                        p_gt = dh.build_p_gt(
                            parMapNet,
                            pose_gt_batch=np.expand_dims(next_im_rel_pose,
                                                         axis=1)).squeeze(1)
                        p_next, map_next = mapNet.forward_single_step(
                            local_info=batch_next,
                            t=t,
                            input_flags=parMapNet.input_flags,
                            map_previous=state[0],
                            p_given=p_gt,
                            update_type=parMapNet.update_type)
                    else:
                        p_next, map_next = mapNet.forward_single_step(
                            local_info=batch_next,
                            t=t,
                            input_flags=parMapNet.input_flags,
                            map_previous=state[0],
                            update_type=parMapNet.update_type)
                    if parIL.use_ego_obsv:
                        enc_in = torch.cat(obsv_batch_next, 1)
                        ego_obsv_feat = ego_encoder(enc_in)  # b x 512 x 1 x 1
                        state = (map_next, p_next, tvec,
                                 torch.tensor([collision],
                                              dtype=torch.float32).cuda(),
                                 ego_obsv_feat)
                    else:
                        state = (map_next, p_next, tvec,
                                 torch.tensor([collision],
                                              dtype=torch.float32).cuda())
                    current_im = next_im

                    # store the data in the batch
                    (imgs_next, points2D_next, local3D_next, sseg_next,
                     dets_next) = batch_next
                    (imgs_obsv_next, dets_obsv_next) = obsv_batch_next
                    imgs_batch[i, t, :, :, :] = imgs_next
                    sseg_batch[i, t, :, :, :] = sseg_next
                    dets_batch[i, t, :, :, :] = dets_next
                    imgs_obsv_batch[i, t, :, :, :] = imgs_obsv_next
                    dets_obsv_batch[i, t, :, :, :] = dets_obsv_next
                    collisions_batch[i, t] = torch.tensor([collision],
                                                          dtype=torch.float32)
                    abs_pose_seq[t, :] = next_im_abs_pose
                    cost = np.asarray(dh.get_state_action_cost(
                        current_im, action_list, annotations, graph),
                                      dtype=np.float32)
                    costs_batch[i, t, :] = torch.from_numpy(cost).float()
                    image_name_seq.append(current_im)
                    points2D_seq.append(points2D_next[0])
                    local3D_seq.append(local3D_next[0])

                else:  # collision case
                    collision = 1
                    if parIL.stop_on_collision:
                        break
                    if parIL.use_ego_obsv:
                        state = (state[0], state[1], state[2],
                                 torch.tensor([collision],
                                              dtype=torch.float32).cuda(),
                                 state[4])
                    else:
                        state = (state[0], state[1], state[2],
                                 torch.tensor([collision],
                                              dtype=torch.float32).cuda())
                    # store the data for the collision case (use the ones from the previous step)
                    imgs_batch[i, t, :, :, :] = imgs_batch[i, t - 1, :, :, :]
                    sseg_batch[i, t, :, :, :] = sseg_batch[i, t - 1, :, :, :]
                    dets_batch[i, t, :, :, :] = dets_batch[i, t - 1, :, :, :]
                    imgs_obsv_batch[i, t, :, :, :] = imgs_obsv_batch[i, t - 1, :, :, :]
                    dets_obsv_batch[i, t, :, :, :] = dets_obsv_batch[i, t - 1, :, :, :]
                    collisions_batch[i, t] = torch.tensor([collision],
                                                          dtype=torch.float32)
                    abs_pose_seq[t, :] = abs_pose_seq[t - 1, :]
                    costs_batch[i, t, :] = costs_batch[i, t - 1, :]
                    image_name_seq.append(current_im)
                    points2D_seq.append(points2D_seq[t - 1])
                    local3D_seq.append(local3D_seq[t - 1])

            # Do the relative pose
            pose_gt_batch[i] = dh.relative_poses(poses=abs_pose_seq)

            # add the remaining batch data where necessary (i.e. lists)
            image_names_batch.append(image_name_seq)
            points2D_batch.append(points2D_seq)
            local3D_batch.append(local3D_seq)

        actions = np.asarray(actions)
        image_names_batch = np.asarray(image_names_batch)

        mapNet_batch = (imgs_batch, points2D_batch, local3D_batch, sseg_batch,
                        dets_batch, pose_gt_batch)
        IL_batch = (imgs_obsv_batch, dets_obsv_batch, tvec_batch,
                    collisions_batch, actions, costs_batch, image_names_batch,
                    scene_batch, scale_batch)
        return mapNet_batch, IL_batch
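# A hypothetical usage sketch (argument values are placeholders): the two batches
# returned by unroll_policy would typically feed the MapNet and IL losses in the
# training loop, e.g.
#   mapNet_batch, IL_batch = unroll_policy(parIL, parMapNet, policy_net, mapNet,
#                                          action_list, batch_size=8, seq_len=10,
#                                          graphs=graphs_dict, ego_encoder=ego_encoder)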
    def sample_episodes(self):
        # Each episode should contain:
        # List of images, list of actions, cost of every action, scene, scale, collision indicators
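        # For each episode, actions are chosen either by the teacher (minimum-cost action)
        # or at random, decided once per episode with probability 0.5.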
        epi_id = 0  # episode id
        im_paths, action_paths, cost_paths, scene_name, scene_scale, target_lbls, pose_paths, collisions = {}, {}, {}, {}, {}, {}, {}, {}
        for scene in self.scene_list:
            annotations, scale, im_names_all, world_poses, directions = dh.load_scene_info(
                self.datasetPath, scene)
            scene_epi_count = 0
            while scene_epi_count < self.n_episodes:
                # Randomly select an image index as the starting position
                idx = np.random.randint(len(im_names_all), size=1)
                im_name_0 = im_names_all[idx[0]]
                # Randomly select a target that exists in that scene
                candidates = dh.candidate_targets(
                    scene, self.cat_dict,
                    self.targets_data)  # Get the list of possible targets
                idx_cat = np.random.randint(len(candidates), size=1)
                cat = candidates[idx_cat[0]]
                target_lbl = self.cat_dict[cat]
                graph = self.graphs_dict[target_lbl][scene]  # to be used to get the ground-truth
                # Choose whether the episode's observations are going to be decided by the
                # teacher (best action) or randomly
                choice = np.random.randint(2, size=1)[0]  # if 1 then do teacher
                im_seq, action_seq, cost_seq, poses_seq, collision_seq = [], [], [], [], []
                im_seq.append(im_name_0)
                current_im = im_name_0
                # get the ground-truth cost for each next state
                cost_seq.append(
                    dh.get_state_action_cost(current_im, self.actions,
                                             annotations, graph))
                poses_seq.append(
                    dh.get_im_pose(im_names_all, current_im, world_poses,
                                   directions, scale))
                collision_seq.append(0)
                for i in range(1, self.seq_len):
                    # either select the best action or ...
                    # ... randomly choose the next action to move in the episode
                    if choice:
                        actions_cost = np.array(cost_seq[i - 1])
                        min_cost = np.min(actions_cost)
                        min_ind = np.where(actions_cost == min_cost)[0]
                        if len(min_ind) == 1:
                            sel_ind = min_ind[0]
                        else:  # if multiple actions have the lowest value then randomly select one
                            sel_ind = min_ind[np.random.randint(len(min_ind), size=1)[0]]
                        sel_action = self.actions[sel_ind]
                    else:
                        sel_action = self.actions[
                            np.random.randint(len(self.actions), size=1)[0]]
                    next_im = annotations[current_im][sel_action]
                    if not (next_im == ''):  # if there is a collision then keep the same image
                        current_im = next_im
                        collision_seq.append(0)
                    else:
                        collision_seq.append(1)
                    im_seq.append(current_im)
                    action_seq.append(sel_action)
                    # get the ground-truth pose
                    poses_seq.append(
                        dh.get_im_pose(im_names_all, current_im, world_poses,
                                       directions, scale))
                    cost_seq.append(
                        dh.get_state_action_cost(current_im, self.actions,
                                                 annotations, graph))

                im_paths[epi_id] = np.asarray(im_seq)
                action_paths[epi_id] = np.asarray(action_seq)
                cost_paths[epi_id] = np.asarray(cost_seq, dtype=np.float32)
                scene_name[epi_id] = scene
                scene_scale[epi_id] = scale
                target_lbls[epi_id] = target_lbl
                pose_paths[epi_id] = np.asarray(poses_seq, dtype=np.float32)
                collisions[epi_id] = np.asarray(collision_seq,
                                                dtype=np.float32)
                epi_id += 1
                scene_epi_count += 1

        self.im_paths = im_paths
        self.action_paths = action_paths
        self.cost_paths = cost_paths
        self.scene_name = scene_name
        self.scene_scale = scene_scale
        self.target_lbls = target_lbls
        self.pose_paths = pose_paths
        self.collisions = collisions