    def estimate_pose(self, pairwise_poses_world):
        """
        Inputs:
            pairwise_poses_world - (K, bs, 3) - (x, y, theta) in world coordinates
        Outputs:
            final_pose    - (bs, num_poses) -- scores over discretized poses
            final_position- (bs, 3) -- (x, y, theta) in world coordinates
        """
        K, bs = pairwise_poses_world.shape[:2]
        # Create the voting map based on the predicted poses
        # Convert to (r, phi, theta) coordinates
        pred_poses_polar = xyt2polar(
            flatten_two(pairwise_poses_world))  # (K*bs, 3)
        pred_poses_map = self.polar2map(pred_poses_polar)  # (K*bs, 3)
        voting_maps = self.map2votes(pred_poses_map)  # (K*bs, 1, mh, mw)
        pairwise_angles = flatten_two(pairwise_poses_world[:, :,
                                                           2])  # (K*bs, )

        voting_maps = unflatten_two(voting_maps, K, bs)  # (K, bs, 1, mh, mw)
        pairwise_angles = pairwise_angles.view(K, bs, 1)  # (K, bs, 1)
        final_pose, _ = self.pose_head.forward(
            voting_maps)  # (bs, num_poses), (bs, )
        final_position = self.pose_head.get_position_and_pose(
            voting_maps, pairwise_angles)  # (bs, 3)
        return final_pose, final_position

    def estimate_pose_mask(self, pairwise_poses_world, masks):
        """
        Inputs:
            pairwise_poses_world - (K, bs, 3) - (x, y, theta) in world coordinates
            masks         - (K, bs) - binary indicating which of the K elements are to be considered
        Outputs:
            final_pose    - (bs, num_poses) -- scores over discretized poses
            final_position- (bs, 3) -- (x, y, theta) in world coordinates
            voting_map    - (bs, 1, mh, mw)

        """
        K, bs = pairwise_poses_world.shape[:2]
        # Create the voting map based on the predicted poses
        # Convert to (r, phi, theta) coordinates
        pred_poses_polar = xyt2polar(
            flatten_two(pairwise_poses_world))  # (K*bs, 3)
        pred_poses_map = self.polar2map(pred_poses_polar)  # (K*bs, 3)
        voting_maps = self.map2votes(pred_poses_map)  # (K*bs, 1, mh, mw)
        pairwise_angles = flatten_two(pairwise_poses_world[:, :,
                                                           2])  # (K*bs, )
        # Mask out the irrelevant samples
        voting_maps = voting_maps * masks.view(K * bs, 1, 1, 1)

        voting_maps = unflatten_two(voting_maps, K, bs)  # (K, bs, 1, mh, mw)
        pairwise_angles = pairwise_angles.view(K, bs, 1)  # (K, bs, 1)
        final_pose, _ = self.pose_head.forward(
            voting_maps)  # (bs, num_poses), (bs, )
        final_position = self.pose_head.get_position_and_pose(
            voting_maps, pairwise_angles)  # (bs, 3)
        voting_maps = voting_maps.sum(dim=0)  # (bs, 1, mh, mw)
        voting_maps = voting_maps + 1e-9  # Add small non-zero votes to all locations
        voting_sum = voting_maps.view(bs, -1).sum(dim=1)  # (bs, )
        voting_sum = voting_sum.view(bs, 1, 1, 1)  # (bs, 1, 1, 1)
        voting_map = voting_maps / voting_sum  # (bs, 1, mh, mw)

        return final_pose, final_position, voting_map
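
# The tensor reshaping and coordinate helpers used throughout are assumed to
# behave as follows. This is a minimal sketch consistent with the shape
# comments above (hypothetical `_sketch` names, not the original helpers).
def _flatten_two_sketch(x):
    # Merge the first two dimensions: (A, B, ...) -> (A*B, ...)
    return x.view(x.shape[0] * x.shape[1], *x.shape[2:])


def _unflatten_two_sketch(x, a, b):
    # Split the first dimension: (A*B, ...) -> (A, B, ...)
    return x.view(a, b, *x.shape[1:])


def _xyt2polar_sketch(xyt):
    # (x, y, theta) -> (r, phi, theta) with r = |(x, y)|, phi = atan2(y, x)
    r = torch.norm(xyt[:, :2], dim=1)
    phi = torch.atan2(xyt[:, 1], xyt[:, 0])
    return torch.stack([r, phi, xyt[:, 2]], dim=1)


def _polar2xyt_sketch(rpt):
    # (r, phi, theta) -> (x, y, theta), the inverse of the conversion above
    x = rpt[:, 0] * torch.cos(rpt[:, 1])
    y = rpt[:, 0] * torch.sin(rpt[:, 1])
    return torch.stack([x, y, rpt[:, 2]], dim=1)
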
def compute_reconstruction_rewards(
    obs_feats,
    obs_odometer,
    tgt_feats,
    tgt_poses,
    cluster_centroids_t,
    decoder,
    pose_encoder,
):
    """
    Inputs:
        obs_feats           - (T, N, nclusters)
        obs_odometer        - (T, N, 3) --- (y, x, theta)
        tgt_feats           - (N, nRef, nclusters)
        tgt_poses           - (N, nRef, 3) --- (y, x, theta)
        cluster_centroids_t - (nclusters, feat_dim)
        decoder             - decoder model
        pose_encoder        - pose_encoder model

    Outputs:
        reward              - (N, nRef) reward for each target: the negative
                              reconstruction (KL-based) loss, so higher values
                              mean the target's GT clusters were reconstructed
                              more accurately.
    """
    T, N, nclusters = obs_feats.shape
    nRef = tgt_feats.shape[1]
    device = obs_feats.device

    obs_feats_exp = obs_feats.unsqueeze(2)
    obs_feats_exp = obs_feats_exp.expand(
        -1, -1, nRef, -1).contiguous()  # (T, N, nRef, nclusters)
    obs_odometer_exp = obs_odometer.unsqueeze(2)
    obs_odometer_exp = obs_odometer_exp.expand(
        -1, -1, nRef, -1).contiguous()  # (T, N, nRef, 3)
    tgt_poses_exp = (tgt_poses.unsqueeze(0).expand(T, -1, -1, -1).contiguous()
                     )  # (T, N, nRef, 3)

    # Compute relative poses
    obs_odometer_exp = obs_odometer_exp.view(T * N * nRef, 3)
    tgt_poses_exp = tgt_poses_exp.view(T * N * nRef, 3)
    obs_relpose = subtract_pose(obs_odometer_exp,
                                tgt_poses_exp)  # (T*N*nRef, 3) --- (x, y, phi)
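    # subtract_pose is assumed to compute a relative pose between its two
    # arguments, returned as (x, y, phi). Note that update() further below
    # calls it with the argument order swapped (tgt first, obs second); the
    # exact frame convention depends on the original helper.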

    # Compute pose encoding
    with torch.no_grad():
        obs_relpose_enc = pose_encoder(obs_relpose)  # (T*N*nRef, 16)
    obs_relpose_enc = obs_relpose_enc.view(T, N, nRef, -1)  # (T, N, nRef, 16)
    tgt_relpose_enc = torch.zeros(1, *obs_relpose_enc.shape[1:]).to(
        device)  # (1, N, nRef, 16)

    # Compute reconstructions
    obs_feats_exp = obs_feats_exp.view(T, N * nRef, nclusters)
    obs_relpose_enc = obs_relpose_enc.view(T, N * nRef, -1)
    tgt_relpose_enc = tgt_relpose_enc.view(1, N * nRef, -1)

    rec_inputs = {
        "history_image_features": obs_feats_exp,
        "history_pose_features": obs_relpose_enc,
        "target_pose_features": tgt_relpose_enc,
    }

    with torch.no_grad():
        pred_logits = decoder(rec_inputs)  # (1, N*nRef, nclusters)
    pred_logits = pred_logits.squeeze(0)  # (N*nRef, nclusters)
    pred_logits = unflatten_two(pred_logits, N, nRef)  # (N, nRef, nclusters)

    # Compute GT classes. NOTE: these top-k statistics are computed but unused
    # by the KL-divergence reward below.
    tgt_feats_sim = tgt_feats  # (N, nRef, nclusters)
    topk_gt = torch.topk(tgt_feats_sim, 5, dim=2)
    topk_gt_values = topk_gt.values  # (N, nRef, 5)
    topk_gt_thresh = topk_gt_values.min(dim=2).values  # (N, nRef)

    # ------------------ KL Div loss based reward --------------------
    reward = -rec_loss_fn_classify(
        flatten_two(pred_logits),
        flatten_two(tgt_feats),
        cluster_centroids_t.t(),
        K=2,
        reduction="none",
    ).sum(dim=1)  # (N*nRef, )
    reward = reward.view(N, nRef)

    return reward
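
# A plausible reading of rec_loss_fn_classify used above (hypothetical sketch,
# not the original): form a target distribution that is uniform over the K
# clusters most similar to the target, and penalize the predicted logits with
# a cross-entropy / KL-style term. The original additionally receives the
# cluster centroids, presumably to compute similarities; here the target
# similarities are passed in directly.
def _rec_loss_sketch(pred_logits, tgt_sims, K=2):
    # pred_logits, tgt_sims: (B, nclusters)
    topk_idx = torch.topk(tgt_sims, K, dim=1).indices  # (B, K)
    log_probs = F.log_softmax(pred_logits, dim=1)  # (B, nclusters)
    tgt_dist = torch.zeros_like(pred_logits)
    tgt_dist.scatter_(1, topk_idx, 1.0 / K)  # uniform mass over the top-K
    # Per-element loss; sum over dim=1 for a per-sample loss as in the caller
    return -(tgt_dist * log_probs)  # (B, nclusters)
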
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    ndevices = torch.cuda.device_count()
    # Setup loggers
    tbwriter = SummaryWriter(log_dir=args.log_dir)
    logging.basicConfig(filename=f"{args.log_dir}/train_log.txt",
                        level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.getLogger().setLevel(logging.INFO)
    if "habitat" in args.env_name:
        devices = [
            int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        ]
        # Devices need to be re-indexed from 0 to N-1
        devices = list(range(len(devices)))
        envs = make_vec_envs_habitat(args.habitat_config_file,
                                     device,
                                     devices,
                                     seed=args.seed)
    else:
        train_log_dir = os.path.join(args.log_dir, "train_monitor")
        try:
            os.makedirs(train_log_dir)
        except OSError:
            pass
        envs = make_vec_envs_avd(
            args.env_name,
            args.seed,
            args.num_processes,
            train_log_dir,
            device,
            True,
            num_frame_stack=1,
            split="train",
            nRef=args.num_pose_refs,
        )

    args.feat_shape_sim = (512, )
    args.obs_shape = envs.observation_space.spaces["im"].shape
    args.odometer_shape = (4, )  # (delta_y, delta_x, delta_head, delta_elev)

    # =================== Load clusters =================
    clusters_h5 = h5py.File(args.clusters_path, "r")
    cluster_centroids = torch.Tensor(np.array(
        clusters_h5["cluster_centroids"])).to(device)
    cluster_centroids_t = cluster_centroids.t()
    args.nclusters = cluster_centroids.shape[0]
    clusters2images = {}
    for i in range(args.nclusters):
        cluster_images = np.array(
            clusters_h5[f"cluster_{i}/images"])  # (K, C, H, W) torch Tensor
        cluster_images = rearrange(cluster_images, "k c h w -> k h w c")
        cluster_images = (cluster_images * 255.0).astype(np.uint8)
        clusters2images[i] = cluster_images  # (K, H, W, C)
    clusters_h5.close()
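    # Assumed layout of the clusters HDF5 file (inferred from the reads above):
    #   cluster_centroids   - (nclusters, feat_dim) float array
    #   cluster_<i>/images  - (K, C, H, W) float array with values in [0, 1]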

    # =================== Create models ====================
    decoder = FeatureReconstructionModule(
        args.nclusters,
        args.nclusters,
        nlayers=args.n_transformer_layers,
    )
    feature_network = FeatureNetwork()
    pose_encoder = PoseEncoder()
    encoder = RGBEncoder() if args.encoder_type == "rgb" else MapRGBEncoder()
    action_config = ({
        "nactions": envs.action_space.n,
        "embedding_size": args.action_embedding_size
    } if args.use_action_embedding else None)
    collision_config = ({
        "collision_dim": 2,
        "embedding_size": args.collision_embedding_size
    } if args.use_collision_embedding else None)
    actor_critic = Policy(
        envs.action_space,
        base_kwargs={
            "feat_dim": args.feat_shape_sim[0],
            "recurrent": True,
            "hidden_size": args.feat_shape_sim[0],
            "action_config": action_config,
            "collision_config": collision_config,
        },
    )

    # =================== Load models ====================
    decoder_state, pose_encoder_state = torch.load(args.load_path_rec)[:2]
    # Remove DataParallel related strings
    new_decoder_state, new_pose_encoder_state = {}, {}
    for k, v in decoder_state.items():
        new_decoder_state[k.replace("module.", "")] = v
    for k, v in pose_encoder_state.items():
        new_pose_encoder_state[k.replace("module.", "")] = v
    decoder.load_state_dict(new_decoder_state)
    pose_encoder.load_state_dict(new_pose_encoder_state)
    decoder = nn.DataParallel(decoder, dim=1)
    pose_encoder = nn.DataParallel(pose_encoder, dim=0)
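    # DataParallel splits inputs along the batch axis: the decoder consumes
    # time-major (T, batch, ...) tensors, hence dim=1, while the pose encoder
    # takes (batch, 3) inputs, hence dim=0.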
    save_path = os.path.join(args.save_dir, "checkpoints")
    checkpoint_path = os.path.join(save_path, "ckpt.latest.pth")
    if os.path.isfile(checkpoint_path):
        print("Resuming from old model!")
        loaded_states = torch.load(checkpoint_path)
        encoder_state, actor_critic_state, j_start = loaded_states
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
    elif args.pretrained_il_model != "":
        logging.info("Initializing with pre-trained model!")
        encoder_state, actor_critic_state, _ = torch.load(
            args.pretrained_il_model)
        actor_critic.load_state_dict(actor_critic_state)
        encoder.load_state_dict(encoder_state)
        j_start = -1
    else:
        j_start = -1
    encoder.to(device)
    actor_critic.to(device)
    decoder.to(device)
    feature_network.to(device)
    pose_encoder.to(device)
    encoder.train()
    actor_critic.train()
    # decoder, feature_network, pose_encoder are frozen during policy training
    decoder.eval()
    feature_network.eval()
    pose_encoder.eval()

    # =================== Define RL training algorithm ====================
    rl_algo_config = {}
    rl_algo_config["lr"] = args.lr
    rl_algo_config["eps"] = args.eps
    rl_algo_config["encoder_type"] = args.encoder_type
    rl_algo_config["max_grad_norm"] = args.max_grad_norm
    rl_algo_config["clip_param"] = args.clip_param
    rl_algo_config["ppo_epoch"] = args.ppo_epoch
    rl_algo_config["entropy_coef"] = args.entropy_coef
    rl_algo_config["num_mini_batch"] = args.num_mini_batch
    rl_algo_config["value_loss_coef"] = args.value_loss_coef
    rl_algo_config["use_clipped_value_loss"] = False
    rl_algo_config["nactions"] = envs.action_space.n

    rl_algo_config["encoder"] = encoder
    rl_algo_config["actor_critic"] = actor_critic
    rl_algo_config["use_action_embedding"] = args.use_action_embedding
    rl_algo_config["use_collision_embedding"] = args.use_collision_embedding

    rl_agent = PPO(rl_algo_config)

    # =================== Define rollouts ====================
    rollouts_recon = RolloutStorageReconstruction(
        args.num_steps,
        args.num_processes,
        (args.nclusters, ),
        args.odometer_shape,
        args.num_pose_refs,
    )
    rollouts_policy = RolloutStoragePPO(
        args.num_rl_steps,
        args.num_processes,
        args.obs_shape,
        envs.action_space,
        args.feat_shape_sim[0],
        encoder_type=args.encoder_type,
    )
    rollouts_recon.to(device)
    rollouts_policy.to(device)

    def get_obs(obs):
        obs_im = process_image(obs["im"])
        if args.encoder_type == "rgb+map":
            obs_lm = process_image(obs["coarse_occupancy"])
            obs_sm = process_image(obs["fine_occupancy"])
        else:
            obs_lm = None
            obs_sm = None
        return obs_im, obs_sm, obs_lm

    start = time.time()
    NPROC = args.num_processes
    NREF = args.num_pose_refs
    for j in range(j_start + 1, num_updates):
        # =================== Start a new episode ====================
        obs = envs.reset()
        # Processing environment inputs
        obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
        obs_odometer = process_odometer(obs["delta"])  # (num_processes, 4)
        # Convert mm to m for AVD
        if "avd" in args.env_name:
            obs_odometer[:, :2] /= 1000.0
        obs_collns = obs["collisions"].long()  # (num_processes, 1)
        # ============== Target poses and corresponding images ================
        # NOTE - these are constant throughout the episode.
        # (num_processes * num_pose_refs, 3) --- (y, x, t)
        tgt_poses = process_odometer(flatten_two(obs["pose_regress"]))[:, :3]
        tgt_poses = unflatten_two(tgt_poses, NPROC, NREF)  # (N, nRef, 3)
        tgt_masks = obs["valid_masks"].unsqueeze(2)  # (N, nRef, 1)
        # Convert mm to m for AVD
        if "avd" in args.env_name:
            tgt_poses[:, :, :2] /= 1000.0
        tgt_ims = process_image(flatten_two(
            obs["pose_refs"]))  # (N*nRef, C, H, W)
        # Initialize the memory of rollouts for reconstruction
        rollouts_recon.reset()
        with torch.no_grad():
            obs_feat = feature_network(obs_im)  # (N, 2048)
            tgt_feat = feature_network(tgt_ims)  # (N*nRef, 2048)
            # Compute similarity scores with all other clusters
            obs_feat = torch.matmul(obs_feat,
                                    cluster_centroids_t)  # (N, nclusters)
            tgt_feat = torch.matmul(tgt_feat,
                                    cluster_centroids_t)  # (N*nRef, nclusters)
        tgt_feat = unflatten_two(tgt_feat, NPROC, NREF)  # (N, nRef, nclusters)
        rollouts_recon.obs_feats[0].copy_(obs_feat)
        rollouts_recon.obs_odometer[0].copy_(obs_odometer)
        rollouts_recon.tgt_poses.copy_(tgt_poses)
        rollouts_recon.tgt_feats.copy_(tgt_feat)
        rollouts_recon.tgt_masks.copy_(tgt_masks)
        # Initialize the memory of rollouts for policy
        rollouts_policy.reset()
        rollouts_policy.obs_im[0].copy_(obs_im)
        if args.encoder_type == "rgb+map":
            rollouts_policy.obs_sm[0].copy_(obs_sm)
            rollouts_policy.obs_lm[0].copy_(obs_lm)
        rollouts_policy.collisions[0].copy_(obs_collns)
        # Episode statistics
        episode_expl_rewards = np.zeros((NPROC, 1))
        episode_collisions = np.zeros((NPROC, 1))
        episode_rec_rewards = np.zeros((NPROC, 1))
        episode_collisions += obs_collns.cpu().numpy()
        # Metrics
        osr_tracker = [0.0 for _ in range(NPROC)]
        objects_tracker = [0.0 for _ in range(NPROC)]
        area_tracker = [0.0 for _ in range(NPROC)]
        novelty_tracker = [0.0 for _ in range(NPROC)]
        smooth_coverage_tracker = [0.0 for _ in range(NPROC)]
        per_proc_area = [0.0 for _ in range(NPROC)]
        # Other states
        prev_action = torch.zeros(NPROC, 1).long().to(device)
        prev_collision = rollouts_policy.collisions[0]
        rec_reward_interval = args.rec_reward_interval
        prev_rec_rewards = torch.zeros(NPROC, 1)  # (N, 1)
        prev_rec_rewards = prev_rec_rewards.to(device)
        rec_rewards_at_t0 = None
        # ================= Update over a full batch of episodes =================
        # num_steps must be total number of steps in each episode
        for step in range(args.num_steps):
            pstep = rollouts_policy.step
            with torch.no_grad():
                encoder_inputs = [rollouts_policy.obs_im[pstep]]
                if args.encoder_type == "rgb+map":
                    encoder_inputs.append(rollouts_policy.obs_sm[pstep])
                    encoder_inputs.append(rollouts_policy.obs_lm[pstep])
                obs_feats = encoder(*encoder_inputs)
                policy_inputs = {"features": obs_feats}
                if args.use_action_embedding:
                    policy_inputs["actions"] = prev_action.long()
                if args.use_collision_embedding:
                    policy_inputs["collisions"] = prev_collision.long()

                policy_outputs = actor_critic.act(
                    policy_inputs,
                    rollouts_policy.recurrent_hidden_states[pstep],
                    rollouts_policy.masks[pstep],
                )
                (
                    value,
                    action,
                    action_log_probs,
                    recurrent_hidden_states,
                ) = policy_outputs

            # Act, get reward and next obs
            obs, reward, done, infos = envs.step(action)

            # Processing environment inputs
            obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
            obs_odometer = process_odometer(obs["delta"])  # (num_processes, 4)
            if "avd" in args.env_name:
                obs_odometer[:, :2] /= 1000.0
            obs_collns = obs["collisions"]  # (N, 1)
            with torch.no_grad():
                obs_feat = feature_network(obs_im)
                # Compute similarity scores with all other clusters
                obs_feat = torch.matmul(obs_feat,
                                        cluster_centroids_t)  # (N, nclusters)

            # Always set masks to 1 (since this loop happens within one episode)
            masks = torch.FloatTensor([[1.0] for _ in range(NPROC)]).to(device)

            # Accumulate odometer readings to give relative pose from the starting point
            obs_odometer = rollouts_recon.obs_odometer[
                step] * masks + obs_odometer

            # Update rollouts_recon
            rollouts_recon.insert(obs_feat, obs_odometer)

            # Compute the exploration rewards
            reward_exploration = torch.zeros(NPROC, 1)  # (N, 1)
            for proc in range(NPROC):
                seen_area = float(infos[proc]["seen_area"])
                objects_visited = infos[proc].get("num_objects_visited", 0.0)
                oracle_success = float(infos[proc]["oracle_pose_success"])
                novelty_reward = infos[proc].get("count_based_reward", 0.0)
                smooth_coverage_reward = infos[proc].get(
                    "coverage_novelty_reward", 0.0)
                # NOTE: these per-step reward terms are computed here but not
                # added to reward_exploration in this variant; only the
                # reconstruction reward below contributes.
                area_reward = seen_area - area_tracker[proc]
                objects_reward = objects_visited - objects_tracker[proc]
                landmarks_reward = oracle_success - osr_tracker[proc]
                collision_reward = -obs_collns[proc, 0].item()

                area_tracker[proc] = seen_area
                objects_tracker[proc] = objects_visited
                osr_tracker[proc] = oracle_success
                per_proc_area[proc] = seen_area
                novelty_tracker[proc] += novelty_reward
                smooth_coverage_tracker[proc] += smooth_coverage_reward

            # Compute reconstruction rewards
            if (step + 1) % rec_reward_interval == 0 or step == 0:
                rec_rewards = compute_reconstruction_rewards(
                    rollouts_recon.obs_feats[:(step + 1)],
                    rollouts_recon.obs_odometer[:(step + 1), :, :3],
                    rollouts_recon.tgt_feats,
                    rollouts_recon.tgt_poses,
                    cluster_centroids_t,
                    decoder,
                    pose_encoder,
                ).detach()  # (N, nRef)
                rec_rewards = rec_rewards * tgt_masks.squeeze(2)  # (N, nRef)
                rec_rewards = rec_rewards.sum(dim=1).unsqueeze(
                    1)  # / (tgt_masks.sum(dim=1) + 1e-8)
                final_rec_rewards = rec_rewards - prev_rec_rewards
                # Ignore the exploration reward at T=0 since it will be a huge spike
                if (("avd" in args.env_name) and
                    (step != 0)) or (("habitat" in args.env_name) and
                                     (step > 20)):
                    reward_exploration += (final_rec_rewards.cpu() *
                                           args.rec_reward_scale)
                    episode_rec_rewards += final_rec_rewards.cpu().numpy()
                prev_rec_rewards = rec_rewards

            overall_reward = (reward * (1 - args.reward_scale) +
                              reward_exploration * args.reward_scale)

            # Update statistics
            episode_expl_rewards += reward_exploration.numpy(
            ) * args.reward_scale

            # Update rollouts_policy
            rollouts_policy.insert(
                obs_im,
                obs_sm,
                obs_lm,
                recurrent_hidden_states,
                action,
                action_log_probs,
                value,
                overall_reward,
                masks,
                obs_collns,
            )

            # Update prev values
            prev_collision = obs_collns
            prev_action = action
            episode_collisions += obs_collns.cpu().numpy()

            # Update RL policy
            if (step + 1) % args.num_rl_steps == 0:
                # Update value function for last step
                with torch.no_grad():
                    encoder_inputs = [rollouts_policy.obs_im[-1]]
                    if args.encoder_type == "rgb+map":
                        encoder_inputs.append(rollouts_policy.obs_sm[-1])
                        encoder_inputs.append(rollouts_policy.obs_lm[-1])
                    obs_feats = encoder(*encoder_inputs)
                    policy_inputs = {"features": obs_feats}
                    if args.use_action_embedding:
                        policy_inputs["actions"] = prev_action.long()
                    if args.use_collision_embedding:
                        policy_inputs["collisions"] = prev_collision.long()
                    next_value = actor_critic.get_value(
                        policy_inputs,
                        rollouts_policy.recurrent_hidden_states[-1],
                        rollouts_policy.masks[-1],
                    ).detach()
                # Compute returns
                rollouts_policy.compute_returns(next_value, args.use_gae,
                                                args.gamma, args.tau)

                # Update model
                rl_losses = rl_agent.update(rollouts_policy)

                # Refresh rollouts_policy
                rollouts_policy.after_update()

        # =================== Save model ====================
        if (j + 1) % args.save_interval == 0 and args.save_dir != "":
            save_path = f"{args.save_dir}/checkpoints"
            try:
                os.makedirs(save_path)
            except OSError:
                pass
            encoder_state = encoder.state_dict()
            actor_critic_state = actor_critic.state_dict()
            torch.save(
                [encoder_state, actor_critic_state, j],
                f"{save_path}/ckpt.latest.pth",
            )
            if args.save_unique:
                torch.save(
                    [encoder_state, actor_critic_state, j],
                    f"{save_path}/ckpt.{(j+1):07d}.pth",
                )

        # =================== Logging data ====================
        total_num_steps = (j + 1 - j_start) * NPROC * args.num_steps
        if j % args.log_interval == 0:
            end = time.time()
            fps = int(total_num_steps / (end - start))
            print(f"===> Updates {j}, #steps {total_num_steps}, FPS {fps}")
            train_metrics = rl_losses
            train_metrics["exploration_rewards"] = (
                np.mean(episode_expl_rewards) * rec_reward_interval /
                args.num_steps)
            train_metrics["rec_rewards"] = (np.mean(episode_rec_rewards) *
                                            rec_reward_interval /
                                            args.num_steps)
            train_metrics["area_covered"] = np.mean(per_proc_area)
            train_metrics["objects_covered"] = np.mean(objects_tracker)
            train_metrics["landmarks_covered"] = np.mean(osr_tracker)
            train_metrics["collisions"] = np.mean(episode_collisions)
            train_metrics["novelty_rewards"] = np.mean(novelty_tracker)
            train_metrics["smooth_coverage_rewards"] = np.mean(
                smooth_coverage_tracker)
            for k, v in train_metrics.items():
                print(f"{k}: {v:.3f}")
                tbwriter.add_scalar(f"train_metrics/{k}", v, j)

        # =================== Evaluate models ====================
        if args.eval_interval is not None and (j +
                                               1) % args.eval_interval == 0:
            if "habitat" in args.env_name:
                devices = [
                    int(dev)
                    for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
                ]
                # Devices need to be re-indexed from 0 to N-1
                devices = list(range(len(devices)))
                eval_envs = make_vec_envs_habitat(
                    args.eval_habitat_config_file, device, devices)
            else:
                eval_envs = make_vec_envs_avd(
                    args.env_name,
                    args.seed + 12,
                    12,
                    eval_log_dir,
                    device,
                    True,
                    split="val",
                    nRef=NREF,
                    set_return_topdown_map=True,
                )

            num_eval_episodes = 16 if "habitat" in args.env_name else 30

            eval_config = {}
            eval_config["num_steps"] = args.num_steps
            eval_config["feat_shape_sim"] = args.feat_shape_sim
            eval_config[
                "num_processes"] = 1 if "habitat" in args.env_name else 12
            eval_config["odometer_shape"] = args.odometer_shape
            eval_config["num_eval_episodes"] = num_eval_episodes
            eval_config["num_pose_refs"] = NREF
            eval_config["env_name"] = args.env_name
            eval_config["actor_type"] = "learned_policy"
            eval_config["encoder_type"] = args.encoder_type
            eval_config["use_action_embedding"] = args.use_action_embedding
            eval_config[
                "use_collision_embedding"] = args.use_collision_embedding
            eval_config["cluster_centroids"] = cluster_centroids
            eval_config["clusters2images"] = clusters2images
            eval_config["rec_loss_fn"] = rec_loss_fn_classify
            eval_config[
                "vis_save_dir"] = f"{args.save_dir}/policy_vis/update_{(j+1):05d}"
            models = {}
            models["decoder"] = decoder
            models["pose_encoder"] = pose_encoder
            models["feature_network"] = feature_network
            models["encoder"] = encoder
            models["actor_critic"] = actor_critic
            val_metrics, _ = evaluate_reconstruction(models, eval_envs,
                                                     eval_config, device)
            decoder.eval()
            pose_encoder.eval()
            feature_network.eval()
            actor_critic.train()
            encoder.train()
            for k, v in val_metrics.items():
                tbwriter.add_scalar(f"val_metrics/{k}", v, j)

    tbwriter.close()
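
# unsq_exp is assumed to unsqueeze a new axis at `dim` and expand it to size
# k; a minimal sketch consistent with its uses below (hypothetical name):
def _unsq_exp_sketch(x, k, dim):
    return x.unsqueeze(dim).expand(
        *x.shape[:dim], k, *x.shape[dim:]).contiguous()
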
def compute_pose_sptm(
    obs_feats_sim,
    obs_feats_pose,
    obs_odometer,
    ref_feats_sim,
    ref_feats_pose,
    config,
    models,
    device,
    env_name,
):
    """
    Inputs:
        obs_feats_sim  - (T, N, feat_size_sim)
        obs_feats_pose - (T, N, feat_size_pose)
        obs_odometer   - (T, N, 4)
        ref_feats_sim  - (N, nRef, feat_size_sim)
        ref_feats_pose - (N, nRef, feat_size_pose)

    Outputs:
        outputs dict with:
            predicted_poses     - (N, nRef, num_poses) pose scores for each
                                  reference based on the past T observations
            predicted_positions - (N, nRef, 3)
            novote_masks        - (N, nRef)
    """

    map_shape = config["map_shape"]
    map_scale = config["map_scale"]
    bin_size = config["bin_size"]
    angles = config["angles"]
    median_filter_size = config["median_filter_size"]
    vote_kernel_size = config["vote_kernel_size"]
    match_thresh = config["match_thresh"]

    rnet = models["rnet"]
    posenet = models["posenet"]  # Pairwise pose predictor
    pose_head = models["pose_head"]

    gaussian_kernel = get_gaussian_kernel(kernel_size=vote_kernel_size,
                                          sigma=2.5,
                                          channels=1)
    gaussian_kernel = gaussian_kernel.to(device)

    # ========== Compute features for similarity prediction ===========
    T, N = obs_feats_sim.shape[:2]
    nRef = ref_feats_sim.shape[1]
    feat_size_sim = obs_feats_sim.shape[2]
    feat_size_pose = obs_feats_pose.shape[2]

    # ======== Compute the positions of each observation on the map ===========
    # obs_odometer ---> (T, N, 4) ---> (y, x, phi_head, phi_elev)
    obs_poses = torch.index_select(
        obs_odometer, 2,
        torch.LongTensor([1, 0,
                          2]).to(device))  # (T, N, 3) ---> (x, y, phi_head)

    # ========== Compute pairwise scores with all prior observations ==========
    all_pred_poses = []
    all_voting_maps = []
    all_topk_idxes = []
    all_paired_scores = []
    all_paired_poses_polar = []
    median_filter = MedianPool1d(median_filter_size, 1,
                                 median_filter_size // 2)
    median_filter.to(device)
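    # Note: this loop runs only for the most recent timestep, t = T - 1.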
    for t in range(T - 1, T):
        obs_feats_sim_curr = obs_feats_sim[:(t + 1)]  # (t+1, N, feat_size_sim)
        obs_feats_sim_curr = unsq_exp(obs_feats_sim_curr, nRef,
                                      dim=2)  # (t+1, N, nRef, feat_size_sim)
        ref_feats_sim_curr = unsq_exp(ref_feats_sim, t + 1,
                                      dim=0)  # (t+1, N, nRef, feat_size_sim)
        obs_feats_sim_curr = obs_feats_sim_curr.view(
            -1, feat_size_sim)  # ((t+1)*N*nRef, feat_size_sim)
        ref_feats_sim_curr = ref_feats_sim_curr.view(
            -1, feat_size_sim)  # ((t+1)*N*nRef, feat_size_sim)

        with torch.no_grad():
            paired_scores = rnet.compare(
                torch.cat([obs_feats_sim_curr, ref_feats_sim_curr],
                          dim=1))  # ((t+1)*N*nRef, 2)
            paired_scores = F.softmax(paired_scores,
                                      dim=1)[:, 1]  # ((t+1)*N*nRef, )
            paired_scores = paired_scores.view(t + 1,
                                               N * nRef)  # (t+1, N*nRef)
            # Apply median filtering
            paired_scores = rearrange(paired_scores,
                                      "t f -> f () t")  # (N*nRef, 1, t+1)
            paired_scores = median_filter(paired_scores)  # (N*nRef, 1, t+1)
            paired_scores = rearrange(paired_scores, "f () t -> t n r",
                                      n=N)  # (t+1, N, nRef)
        # Top K matches
        k = min(paired_scores.shape[0], 10)
        topk_scores, topk_idx = torch.topk(paired_scores, k=k,
                                           dim=0)  # (k, N, nRef)

        # Compute pose predictions for each match
        obs_poses_curr = obs_poses[:(t + 1)]  # (t+1, N, 3)
        obs_poses_curr = unsq_exp(obs_poses_curr, nRef,
                                  dim=2)  # (t+1, N, nRef, 3)
        obs_feats_pose_curr = unsq_exp(obs_feats_pose[:(t + 1)], nRef,
                                       dim=2)  # (t+1, N, nRef, feat_size_pose)
        topk_idx_exp = unsq_exp(topk_idx, feat_size_pose,
                                dim=3)  # (k, N, nRef, feat_size_pose)
        topk_obs_feats_pose = torch.gather(
            obs_feats_pose_curr, 0,
            topk_idx_exp)  # (k, N, nRef, feat_size_pose)
        topk_idx_exp = topk_idx.unsqueeze(3).expand(-1, -1, -1, 3)
        topk_obs_poses = torch.gather(obs_poses_curr, 0,
                                      topk_idx_exp)  # (k, N, nRef, 3)

        ref_feats_pose_k = unsq_exp(ref_feats_pose, k,
                                    dim=0)  # (k, N, nRef, feat_size_pose)
        topk_obs_feats_pose = topk_obs_feats_pose.view(
            -1, feat_size_pose)  # (k * N * nRef, feat_size_pose)
        topk_obs_poses = topk_obs_poses.view(-1, 3)  # (k * N * nRef, 3)
        ref_feats_pose_k = ref_feats_pose_k.view(
            -1, feat_size_pose)  # (k * N * nRef, feat_size_pose)

        with torch.no_grad():
            # (k * N * nRef, 3) ---> delta_x, delta_y, delta_theta
            pred_dposes = posenet.get_pose_xyt_feats(topk_obs_feats_pose,
                                                     ref_feats_pose_k)
            if "avd" in env_name:
                pred_dposes[:, :2] *= 1000.0  # (m -> mm)

        # Convert pred_dposes from observation centric coordinate
        # to the world coordinates.
        # (k * N * nRef, 3) ---> delta_x, delta_y, delta_theta
        pred_dposes_polar = xyt2polar(pred_dposes)
        # Add the observation's world heading to the delta's bearing
        pred_dposes_polar[:, 1] += topk_obs_poses[:, 2]
        pred_dposes_world = polar2xyt(
            pred_dposes_polar)  # Convert delta pose to world coordinate
        pred_poses_world = (pred_dposes_world + topk_obs_poses
                            )  # Get real world pose in xyt system
        all_pred_poses.append(pred_poses_world.view(k, N, nRef, 3))

        # Create the voting map based on the predicted poses
        pred_poses_polar = xyt2polar(
            pred_poses_world)  # Convert to (r, phi, theta) coordinates
        pred_poses_map = process_poseref(pred_poses_polar, map_shape,
                                         map_scale, angles, bin_size /
                                         2).long()  # (k * N * nRef, 3)
        pred_poses_oh = torch.zeros(k * N * nRef, *map_shape).to(
            device)  # (k * N * nRef, 1, mh, mw)
        pred_poses_oh[range(k * N * nRef), 0, pred_poses_map[:, 1],
                      pred_poses_map[:, 0]] = 1
        with torch.no_grad():
            pred_poses_smooth = gaussian_kernel(pred_poses_oh)
        pred_poses_smooth = pred_poses_smooth.view(
            k, N, nRef, *map_shape)  # (k, N, nRef, 1, mh, mw)

        # Top K matches filtered by match threshold
        thresh_filter = (topk_scores > match_thresh).float()  #  (k, N, nRef)
        thresh_filter = rearrange(thresh_filter, "k n r -> k n r () () ()")
        pred_poses_smooth = pred_poses_smooth * thresh_filter

        voting_map = pred_poses_smooth  # (k, N, nRef, 1, mh, mw)

        all_voting_maps.append(voting_map)
        all_topk_idxes.append(topk_idx)
        all_paired_scores.append(paired_scores)
        all_paired_poses_polar.append(pred_poses_polar.cpu().view(
            k, N, nRef, 3))

    all_pred_poses = torch.cat(all_pred_poses, dim=1)  # (k, N, nRef, 3)
    all_voting_maps = torch.cat(all_voting_maps,
                                dim=1)  # (k, N, nRef, 1, mh, mw)
    all_topk_idxes = torch.cat(all_topk_idxes, dim=0)  # (k, N, nRef)
    all_paired_scores = torch.cat(all_paired_scores, dim=0)  # (T, N, nRef)
    all_paired_poses_polar = torch.cat(all_paired_poses_polar,
                                       dim=0)  # (k, N, nRef, 3)

    # ========== Predict pose ============
    all_voting_maps = rearrange(
        all_voting_maps,
        "k n r c h w -> k (n r) c h w")  # (k, N*nRef, 1, mh, mw)
    predicted_poses, novote_masks = pose_head.forward(
        all_voting_maps)  # (N*nRef, num_poses), (N*nRef, )
    predicted_poses = unflatten_two(predicted_poses, N,
                                    nRef)  # (N, nRef, num_poses)
    novote_masks = unflatten_two(novote_masks, N, nRef)  # (N, nRef)

    all_pred_pose_angles = all_pred_poses[..., 2].view(
        k, N * nRef).unsqueeze(2)  # (k, N*nRef, 1)
    predicted_positions = pose_head.get_position_and_pose(
        all_voting_maps, all_pred_pose_angles)
    predicted_positions = unflatten_two(predicted_positions, N,
                                        nRef)  # (N, nRef, 3) ---> (x, y, t)

    outputs = {}
    outputs["predicted_poses"] = predicted_poses  # (N, nRef, num_poses)
    outputs["predicted_positions"] = predicted_positions  # (N, nRef, 3)
    outputs["novote_masks"] = novote_masks  # (N, nRef)

    return outputs
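
# get_gaussian_kernel (used above) is assumed to return a frozen depthwise
# convolution that blurs each channel with a 2D Gaussian; a minimal sketch
# under that assumption (hypothetical name, not the original helper):
def _gaussian_kernel_sketch(kernel_size, sigma, channels):
    ax = torch.arange(kernel_size).float() - (kernel_size - 1) / 2.0
    g1d = torch.exp(-ax.pow(2) / (2 * sigma ** 2))
    g2d = g1d[:, None] * g1d[None, :]
    g2d = g2d / g2d.sum()  # normalize so votes keep their total mass
    conv = nn.Conv2d(channels, channels, kernel_size,
                     padding=kernel_size // 2, groups=channels, bias=False)
    conv.weight.data.copy_(g2d.expand(channels, 1, kernel_size, kernel_size))
    conv.weight.requires_grad_(False)
    return conv
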
def compute_pose_sptm_ransac(
    obs_feats_sim,
    obs_feats_pose,
    obs_odometer,
    ref_feats_sim,
    ref_feats_pose,
    config,
    models,
    device,
    env_name,
):
    """
    Given a history of observations in the form of features, odometer readings,
    estimate the location of a set of references given their features.

    Inputs:
        obs_feats_sim  - (T, N, feat_size_sim)
        obs_feats_pose - (T, N, feat_size_pose)
        obs_odometer   - (T, N, 4)
        ref_feats_sim  - (N, nRef, feat_size_sim)
        ref_feats_pose - (N, nRef, feat_size_pose)

    Outputs:
        predicted_poses      - (N, nRef, num_poses)
        predicted_positions  - (N, nRef, 3)
        novote_masks         - (N, nRef)
        all_paired_poses_map - (T, N, nRef, 3)
        obs_poses            - (T, N, 3)
        inlier_mask          - (T, N, nRef)
        all_pairwise_scores  - (T, N, nRef)
        voting_maps          - (N, nRef, 1, mh, mw)
    """

    map_shape = config["map_shape"]
    map_scale = config["map_scale"]
    bin_size = config["bin_size"]
    angles = config["angles"]
    median_filter_size = config["median_filter_size"]
    match_thresh = config["match_thresh"]

    rnet = models["rnet"]
    posenet = models["posenet"]  # Pairwise pose predictor
    pose_head = models["pose_head"]
    ransac_estimator = models["ransac_estimator"]

    # ========== Compute features for similarity prediction ===========
    T, N = obs_feats_sim.shape[:2]
    nRef = ref_feats_sim.shape[1]
    feat_size_sim = obs_feats_sim.shape[2]

    # ========== Compute the positions of each observation on the map ============
    # obs_odometer ---> (T, N, 4) ---> (y, x, phi_head, phi_elev)
    obs_poses = torch.index_select(
        obs_odometer, 2,
        torch.LongTensor([1, 0,
                          2]).to(device))  # (T, N, 3) ---> (x, y, phi_head)

    # ========== Compute pairwise similarity with all prior observations ============
    median_filter = MedianPool1d(median_filter_size, 1,
                                 median_filter_size // 2)
    median_filter.to(device)
    obs_feats_sim = repeat(obs_feats_sim, "t n f -> (t n r) f", r=nRef)
    ref_feats_sim = repeat(ref_feats_sim, "n r f -> (t n r) f", t=T)
    with torch.no_grad():
        paired_scores = rnet.compare(
            torch.cat([obs_feats_sim, ref_feats_sim], dim=1))
    paired_scores = F.softmax(paired_scores, dim=1)[:, 1]  # (T*N*nRef, )
    paired_scores = rearrange(paired_scores, "(t n r) -> t (n r)", t=T, n=N)
    if paired_scores.shape[0] > 1:
        # Apply median filtering
        paired_scores = rearrange(paired_scores, "t nr -> nr () t")
        paired_scores = median_filter(paired_scores)
        paired_scores = rearrange(paired_scores, "nr () t -> t nr")
    paired_scores = rearrange(paired_scores, "t (n r) -> t n r", t=T, n=N)

    # ========== Compute pairwise poses with all prior observations ============
    ref_feats_pose = repeat(ref_feats_pose, "n r f -> (t n r) f", t=T)
    obs_feats_pose = repeat(obs_feats_pose, "t n f -> (t n r) f", r=nRef)
    with torch.no_grad():
        pairwise_dposes = posenet.get_pose_xyt_feats(obs_feats_pose,
                                                     ref_feats_pose)
        if "avd" in env_name:
            pairwise_dposes[:, :2] *= 1000.0  # (m -> mm)

    # ============= Add pairwise delta to observation pose ==============
    obs_poses_rep = repeat(obs_poses, "t n p -> (t n r) p", r=nRef)
    pairwise_poses_world = add_pose(obs_poses_rep, pairwise_dposes, mode="xyt")
    pairwise_poses_world = pairwise_poses_world.view(T, N, nRef,
                                                     3)  # (x, y, t)
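    # add_pose (mode="xyt") is assumed to compose the egocentric delta with
    # the observation's world pose, mirroring the explicit composition in
    # compute_pose_sptm: rotate (dx, dy) by the observation's heading, add it
    # to the observation's (x, y), and sum the headings.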

    # ========== Define similarity weighted sampling function ==========
    paired_scores = rearrange(paired_scores, "t n r -> t (n r)")
    # When no samples fall above the match_thresh, set match_thresh to a lower value
    match_thresh_mask = (paired_scores > match_thresh).sum(
        dim=0) == 0  # (N*nRef, )
    batch_match_thresh = torch.ones(N * nRef).to(device) * match_thresh
    # If any element has zero samples above match threshold
    if match_thresh_mask.sum().item() > 0:
        batch_match_thresh[match_thresh_mask] = (
            paired_scores[:, match_thresh_mask].max(dim=0)[0] - 0.001
        )  # (N*nRef, )
    batch_match_thresh = batch_match_thresh.unsqueeze(0)  # (1, N*nRef)
    # Compute mask indicating validity of samples along time
    valid_masks = paired_scores > batch_match_thresh  # (T, N*nRef)

    # Assign zero weights to observations below a matching threshold
    sample_weights = (paired_scores *
                      (paired_scores > batch_match_thresh).float()
                      )  # (T, N*nRef)
    pairwise_poses_world = rearrange(pairwise_poses_world,
                                     "t n r p -> t (n r) p")
    (
        pred_pose_inliers,
        pred_position_inliers,
        voting_map_inliers,
        inlier_mask,
    ) = ransac_estimator.ransac_pose_estimation(pairwise_poses_world,
                                                sample_weights, valid_masks)
    novote_masks = match_thresh_mask
    # pred_pose_inliers - (N*nRef, num_poses), pred_position_inliers - (N*nRef, 3)
    # novote_masks - (N*nRef, ), voting_map_inliers - (N*nRef, 1, mh, mw)
    pairwise_poses_map = ransac_estimator.polar2map(
        xyt2polar(flatten_two(pairwise_poses_world)))  # (T*N*nRef, 3)
    pred_pose_inliers = unflatten_two(pred_pose_inliers, N,
                                      nRef)  # (N, nRef, num_poses)
    pred_position_inliers = unflatten_two(pred_position_inliers, N,
                                          nRef)  # (N, nRef, 3)
    pairwise_poses_map = pairwise_poses_map.view(T, N, nRef,
                                                 3)  # (T, N, nRef, 3)
    # obs_poses already has shape (T, N, 3)
    novote_masks = novote_masks.view(N, nRef)  # (N, nRef)
    inlier_mask = inlier_mask.view(T, N, nRef)  # (T, N, nRef)
    pairwise_scores = paired_scores.view(T, N, nRef)  # (T, N, nRef)
    final_voting_maps = unflatten_two(voting_map_inliers, N,
                                      nRef)  # (N, nRef, 1, mh, mw)

    outputs = {}
    outputs["predicted_poses"] = pred_pose_inliers
    outputs["predicted_positions"] = pred_position_inliers
    outputs["novote_masks"] = novote_masks
    outputs["all_paired_poses_map"] = pairwise_poses_map
    outputs["obs_poses"] = obs_poses
    outputs["inlier_mask"] = inlier_mask
    outputs["all_pairwise_scores"] = pairwise_scores
    outputs["voting_maps"] = final_voting_maps
    outputs[
        "successful_votes"] = inlier_mask  # The observations which successfully voted

    return outputs
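
# A minimal sketch of similarity-weighted RANSAC over pairwise pose hypotheses
# (hypothetical; the actual ransac_estimator also builds voting maps and
# returns per-hypothesis inlier masks):
def _ransac_pose_sketch(poses, weights, n_iters=50, inlier_thresh=0.5):
    # poses: (T, B, 3) pose hypotheses; weights: (T, B) sampling weights
    T, B = weights.shape
    probs = weights + 1e-8  # avoid degenerate all-zero rows
    probs = probs / probs.sum(dim=0, keepdim=True)
    best_pose = poses[0].clone()  # (B, 3)
    best_count = torch.zeros(B)
    for _ in range(n_iters):
        # Sample one candidate hypothesis per batch element
        idx = torch.multinomial(probs.t(), 1).squeeze(1)  # (B, )
        cand = poses[idx, torch.arange(B)]  # (B, 3)
        # Count weighted hypotheses that agree with the candidate position
        dist = torch.norm(poses[:, :, :2] - cand[None, :, :2], dim=2)  # (T, B)
        count = ((dist < inlier_thresh) & (weights > 0)).float().sum(dim=0)
        better = count > best_count
        best_pose[better] = cand[better]
        best_count = torch.max(best_count, count)
    return best_pose  # (B, 3) consensus pose estimates
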
    def update(self, rollouts):
        T, N, nfeats = rollouts.obs_feats[:-1].shape
        nRef = rollouts.tgt_feats.shape[1]
        device = rollouts.obs_feats.device
        avg_loss = 0.0
        avg_loss_count = 0.0
        tgt_feats = rollouts.tgt_feats  # (N, nRef, nfeats)
        tgt_masks = rollouts.tgt_masks.squeeze(2)  # (N, nRef)
        obs_feats = unsq_exp(rollouts.obs_feats, nRef,
                             dim=2)  # (T+1, N, nRef, nfeats)
        obs_poses = unsq_exp(rollouts.obs_odometer[:, :, :3], nRef,
                             dim=2)  # (T+1, N, nRef, 3) - (y, x, phi)
        tgt_poses = unsq_exp(rollouts.tgt_poses, T + 1,
                             dim=0)  # (T+1, N, nRef, 3)
        # Make a prediction every prediction_interval steps, i.e., the
        # prediction at iteration i uses the first
        # L = min(i + prediction_interval, T) observations.
        for i in range(0, T, self.prediction_interval):
            L = min(i + self.prediction_interval, T)
            # Estimate relative pose b/w targets and observations.
            obs_relpose = subtract_pose(
                rearrange(tgt_poses[:L], "l b n f -> (l b n) f"),
                rearrange(obs_poses[:L], "l b n f -> (l b n) f"),
            )  # (L*N*nRef, 3) --- (x, y, phi)
            # ========================= Forward pass ==========================
            # Encode the poses of the observations and targets.
            obs_relpose_enc = self.pose_encoder(obs_relpose)  # (L*N*nRef, 16)
            obs_relpose_enc = obs_relpose_enc.view(L, N * nRef, -1)
            tgt_relpose_enc = torch.zeros(
                1, *obs_relpose_enc.shape[1:]).to(device)
            obs_feats_i = rearrange(obs_feats[:L], "l b n f -> l (b n) f")
            # These serve as inputs to an encoder-decoder transformer model.
            rec_inputs = {
                # encoder inputs
                "history_image_features": obs_feats_i,  # (L, N*nRef, nfeats)
                "history_pose_features": obs_relpose_enc,  # (L, N*nRef, 16)
                # decoder inputs
                "target_pose_features": tgt_relpose_enc,  # (1, N*nRef, 16)
            }
            pred_logits = self.decoder(rec_inputs).squeeze(
                0)  # (N*nRef, nclass)
            # =================== Compute reconstruction loss =================
            loss = self.rec_loss_fn(
                pred_logits,  # (N*nRef, nclass)
                flatten_two(tgt_feats),  # (N*nRef, nfeats)
                self.cluster_centroids,
                K=self.rec_loss_fn_J,
                reduction="none",
            ).sum(dim=1)  # (N*nRef, )
            loss = unflatten_two(loss, N, nRef)
            # Mask out invalid targets.
            loss = loss * tgt_masks
            loss = loss.mean()
            # ========================== Backward pass ========================
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(
                itertools.chain(
                    self.decoder.parameters(),
                    self.pose_encoder.parameters(),
                ),
                self.max_grad_norm,
            )
            self.optimizer.step()

            avg_loss += loss.item()
            avg_loss_count += 1.0

        avg_loss = avg_loss / avg_loss_count
        losses = {"rec_loss": avg_loss}
        return losses
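
    # Typical usage (sketch): called once per collected rollout, e.g.
    #   losses = trainer.update(rollouts_recon)  # {"rec_loss": float}
    # where `trainer` is assumed to hold the decoder, pose_encoder, optimizer,
    # rec_loss_fn, and cluster centroids configured elsewhere.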