def estimate_pose(self, pairwise_poses_world):
    """
    Inputs:
        pairwise_poses_world - (K, bs, 3) - (x, y, theta) in world coordinates
    Outputs:
        final_pose     - (bs, num_poses) -- pose prediction scores
        final_position - (bs, 3) -- (x, y, theta) in world coordinates
    """
    K, bs = pairwise_poses_world.shape[:2]
    # Create the voting map based on the predicted poses.
    # Convert to (r, phi, theta) coordinates.
    pred_poses_polar = xyt2polar(
        flatten_two(pairwise_poses_world))  # (K*bs, 3)
    pred_poses_map = self.polar2map(pred_poses_polar)  # (K*bs, 3)
    voting_maps = self.map2votes(pred_poses_map)  # (K*bs, 1, mh, mw)
    pairwise_angles = flatten_two(pairwise_poses_world[:, :, 2])  # (K*bs, )
    voting_maps = unflatten_two(voting_maps, K, bs)  # (K, bs, 1, mh, mw)
    pairwise_angles = pairwise_angles.view(K, bs, 1)  # (K, bs, 1)
    final_pose, _ = self.pose_head.forward(
        voting_maps)  # (bs, num_poses), (bs, )
    final_position = self.pose_head.get_position_and_pose(
        voting_maps, pairwise_angles)  # (bs, 3)

    return final_pose, final_position
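# NOTE: `flatten_two` and `unflatten_two` are imported helpers rather than
# being defined in this file. A minimal sketch consistent with the shape
# comments above (an assumption about their semantics, not the repo's actual
# implementation): merge / split the first two dimensions.
def _flatten_two_sketch(x):
    """Merge the first two dims: (A, B, ...) -> (A*B, ...)."""
    return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])


def _unflatten_two_sketch(x, a, b):
    """Split the first dim back out: (A*B, ...) -> (A, B, ...)."""
    return x.reshape(a, b, *x.shape[1:])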
def update(self, rollouts):
    """Update the policy based on expert data in the rollouts."""
    T, N = rollouts.actions.shape[:2]
    expert_actions = rollouts.actions  # (T, N, 1)
    # Masks indicating when expert actions were *not* taken. This permits
    # a form of data augmentation where non-expert actions are taken to
    # accommodate distribution shifts between the expert and the learned
    # policy.
    action_masks = rollouts.action_masks  # (T, N, 1)
    hxs = rollouts.recurrent_hidden_states[0].unsqueeze(0)  # (1, N, nfeats)
    masks = rollouts.masks[:-1]  # (T, N, 1)

    # ============= Update inflection factor if applicable ================
    if self.use_inflection_weighting:
        inflection_mask = self._get_inflection_mask(expert_actions)
        # Inverse frequency of inflection points.
        inflection_factor = T / (inflection_mask.sum(dim=0) + 1e-12)
        inflection_factor = torch.clamp(inflection_factor, 1.0,
                                        self.trunc_factor_clipping)
        self._update_inflection_factor(inflection_factor.mean().item())

    # ========================= Forward pass ==============================
    hxs = flatten_two(hxs)  # (N, nfeats)
    masks = flatten_two(masks)  # (T*N, 1)
    action_masks = flatten_two(action_masks).squeeze(1)  # (T*N, )
    policy_inputs = self._create_policy_inputs(rollouts)
    pred_action_log_probs = self.actor_critic.get_log_probs(
        policy_inputs, hxs, masks)  # (T*N, nactions)

    # ==================== Compute the prediction loss ====================
    expert_actions = flatten_two(expert_actions).squeeze(1).long()  # (T*N, )
    action_loss = F.nll_loss(pred_action_log_probs, expert_actions,
                             reduction="none")  # (T*N, )

    # Weight the loss based on inflection points.
    if self.use_inflection_weighting:
        inflection_mask = flatten_two(inflection_mask).squeeze(1)  # (T*N, )
        action_loss = action_loss * (
            inflection_mask * self.inflection_factor +
            (1 - inflection_mask) * 1.0)

    # Mask out the losses for non-expert actions.
    action_loss = (action_loss * action_masks).sum() / (action_masks.sum() +
                                                        1e-10)

    # ============================ Backward pass ==========================
    self.optimizer.zero_grad()
    action_loss.backward()
    nn.utils.clip_grad_norm_(
        chain(self.encoder.parameters(), self.actor_critic.parameters()),
        self.max_grad_norm,
    )
    self.optimizer.step()

    losses = {"action_loss": action_loss.item()}
    return losses
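# NOTE: `_get_inflection_mask` is defined elsewhere in the repo. A plausible
# sketch of inflection weighting, assuming a timestep counts as an
# "inflection point" when the expert action differs from the previous one
# (with t=0 always included):
def _get_inflection_mask_sketch(expert_actions):
    """expert_actions - (T, N, 1) long tensor -> (T, N, 1) float mask."""
    inflection_mask = torch.ones_like(expert_actions).float()
    # 1.0 wherever the expert switches actions between consecutive steps.
    inflection_mask[1:] = (expert_actions[1:] != expert_actions[:-1]).float()
    return inflection_mask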
def _create_policy_inputs(self, rollouts):
    """The policy inputs consist of features extracted from the RGB and
    top-down occupancy maps, along with learned embeddings of the previous
    actions and collision detections.
    """
    obs_im = rollouts.obs_im[:-1]  # (T, N, *obs_shape)
    encoder_inputs = [obs_im]
    if self.encoder_type == "rgb+map":
        encoder_inputs.append(rollouts.obs_sm[:-1])  # (T, N, *obs_shape)
        encoder_inputs.append(rollouts.obs_lm[:-1])  # (T, N, *obs_shape)
    encoder_inputs = [flatten_two(v) for v in encoder_inputs]
    obs_feats = self.encoder(*encoder_inputs)  # (T*N, nfeats)
    policy_inputs = {"features": obs_feats}
    if self.use_action_embedding:
        # Shift the actions forward by one step so that each timestep sees
        # the action taken at the *previous* timestep (zero at t=0).
        prev_actions = torch.zeros_like(rollouts.actions)  # (T, N, 1)
        prev_actions[1:] = rollouts.actions[:-1]
        prev_actions = flatten_two(prev_actions)  # (T*N, 1)
        policy_inputs["actions"] = prev_actions.long()
    if self.use_collision_embedding:
        prev_collisions = flatten_two(rollouts.collisions[:-1])  # (T*N, 1)
        policy_inputs["collisions"] = prev_collisions.long()
    return policy_inputs
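# Worked example of the previous-action shift above (hypothetical toy values
# for one process):
#   actions      = [2, 0, 1, 3]   # (T=4,)
#   prev_actions = [0, 2, 0, 1]   # step 0 receives a null (zero) action,
#                                 # every later step sees the action before it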
def estimate_pose_mask(self, pairwise_poses_world, masks):
    """
    Inputs:
        pairwise_poses_world - (K, bs, 3) - (x, y, theta) in world coordinates
        masks - (K, bs) - binary mask indicating which of the K elements
                are to be considered
    Outputs:
        final_pose     - (bs, num_poses) -- pose prediction scores
        final_position - (bs, 3) -- (x, y, theta) in world coordinates
        voting_map     - (bs, 1, mh, mw)
    """
    K, bs = pairwise_poses_world.shape[:2]
    # Create the voting map based on the predicted poses.
    # Convert to (r, phi, theta) coordinates.
    pred_poses_polar = xyt2polar(
        flatten_two(pairwise_poses_world))  # (K*bs, 3)
    pred_poses_map = self.polar2map(pred_poses_polar)  # (K*bs, 3)
    voting_maps = self.map2votes(pred_poses_map)  # (K*bs, 1, mh, mw)
    pairwise_angles = flatten_two(pairwise_poses_world[:, :, 2])  # (K*bs, )
    # Mask out the irrelevant samples.
    voting_maps = voting_maps * masks.view(K * bs, 1, 1, 1)
    voting_maps = unflatten_two(voting_maps, K, bs)  # (K, bs, 1, mh, mw)
    pairwise_angles = pairwise_angles.view(K, bs, 1)  # (K, bs, 1)
    final_pose, _ = self.pose_head.forward(
        voting_maps)  # (bs, num_poses), (bs, )
    final_position = self.pose_head.get_position_and_pose(
        voting_maps, pairwise_angles)  # (bs, 3)
    # Aggregate per-sample votes and normalize into a probability map over
    # locations.
    voting_maps = voting_maps.sum(dim=0)  # (bs, 1, mh, mw)
    voting_maps = voting_maps + 1e-9  # Add small non-zero votes to all locations
    voting_sum = voting_maps.view(bs, -1).sum(dim=1)  # (bs, )
    voting_sum = voting_sum.view(bs, 1, 1, 1)  # (bs, 1, 1, 1)
    voting_map = voting_maps / voting_sum  # (bs, 1, mh, mw)

    return final_pose, final_position, voting_map
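# NOTE: `xyt2polar` is an imported helper. A minimal sketch of the assumed
# conversion from (x, y, theta) to (r, phi, theta) polar coordinates:
def _xyt2polar_sketch(xyt):
    """xyt - (B, 3) of (x, y, theta) -> (B, 3) of (r, phi, theta)."""
    r = torch.norm(xyt[:, :2], dim=1)        # distance from the origin
    phi = torch.atan2(xyt[:, 1], xyt[:, 0])  # bearing of the (x, y) offset
    return torch.stack([r, phi, xyt[:, 2]], dim=1)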
def compute_reconstruction_rewards(
    obs_feats,
    obs_odometer,
    tgt_feats,
    tgt_poses,
    cluster_centroids_t,
    decoder,
    pose_encoder,
):
    """
    Inputs:
        obs_feats           - (T, N, nclusters)
        obs_odometer        - (T, N, 3) --- (y, x, theta)
        tgt_feats           - (N, nRef, nclusters)
        tgt_poses           - (N, nRef, 3) --- (y, x, theta)
        cluster_centroids_t - (nclusters, feat_dim)
        decoder             - decoder model
        pose_encoder        - pose_encoder model
    Outputs:
        reward - (N, nRef) float values indicating how many GT clusters were
                 successfully retrieved for each target.
    """
    T, N, nclusters = obs_feats.shape
    nRef = tgt_feats.shape[1]
    device = obs_feats.device

    obs_feats_exp = obs_feats.unsqueeze(2)
    obs_feats_exp = obs_feats_exp.expand(
        -1, -1, nRef, -1).contiguous()  # (T, N, nRef, nclusters)
    obs_odometer_exp = obs_odometer.unsqueeze(2)
    obs_odometer_exp = obs_odometer_exp.expand(
        -1, -1, nRef, -1).contiguous()  # (T, N, nRef, 3)
    tgt_poses_exp = (tgt_poses.unsqueeze(0).expand(
        T, -1, -1, -1).contiguous())  # (T, N, nRef, 3)

    # Compute relative poses.
    obs_odometer_exp = obs_odometer_exp.view(T * N * nRef, 3)
    tgt_poses_exp = tgt_poses_exp.view(T * N * nRef, 3)
    obs_relpose = subtract_pose(
        obs_odometer_exp, tgt_poses_exp)  # (T*N*nRef, 3) --- (x, y, phi)

    # Compute pose encodings.
    with torch.no_grad():
        obs_relpose_enc = pose_encoder(obs_relpose)  # (T*N*nRef, 16)
    obs_relpose_enc = obs_relpose_enc.view(T, N, nRef, -1)  # (T, N, nRef, 16)
    tgt_relpose_enc = torch.zeros(1, *obs_relpose_enc.shape[1:]).to(
        device)  # (1, N, nRef, 16)

    # Compute reconstructions.
    obs_feats_exp = obs_feats_exp.view(T, N * nRef, nclusters)
    obs_relpose_enc = obs_relpose_enc.view(T, N * nRef, -1)
    tgt_relpose_enc = tgt_relpose_enc.view(1, N * nRef, -1)

    rec_inputs = {
        "history_image_features": obs_feats_exp,
        "history_pose_features": obs_relpose_enc,
        "target_pose_features": tgt_relpose_enc,
    }

    with torch.no_grad():
        pred_logits = decoder(rec_inputs)  # (1, N*nRef, nclusters)
    pred_logits = pred_logits.squeeze(0)  # (N*nRef, nclusters)
    pred_logits = unflatten_two(pred_logits, N, nRef)  # (N, nRef, nclusters)

    # Compute GT classes (note: these thresholds are not used by the
    # KL-divergence reward below).
    tgt_feats_sim = tgt_feats  # (N, nRef, nclusters)
    topk_gt = torch.topk(tgt_feats_sim, 5, dim=2)
    topk_gt_values = topk_gt.values  # (N, nRef, 5)
    topk_gt_thresh = topk_gt_values.min(dim=2).values  # (N, nRef)

    # ------------------ KL-divergence-based reward --------------------
    reward = -rec_loss_fn_classify(
        flatten_two(pred_logits),
        flatten_two(tgt_feats),
        cluster_centroids_t.t(),
        K=2,
        reduction="none",
    ).sum(dim=1)  # (N*nRef, )
    reward = reward.view(N, nRef)

    return reward
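# NOTE: `subtract_pose` is an imported helper. A sketch of the assumed
# semantics -- the relative pose of `pose_b` expressed in the egocentric
# frame of `pose_a`; the repo's actual axis convention (x, y vs. y, x) may
# differ, so treat this purely as an illustration:
def _subtract_pose_sketch(pose_a, pose_b):
    """pose_a, pose_b - (B, 3) of (x, y, theta) -> (B, 3) relative pose."""
    dx = pose_b[:, 0] - pose_a[:, 0]
    dy = pose_b[:, 1] - pose_a[:, 1]
    ct, st = torch.cos(pose_a[:, 2]), torch.sin(pose_a[:, 2])
    x_rel = ct * dx + st * dy   # rotate the offset into a's frame
    y_rel = -st * dx + ct * dy
    # Wrap the heading difference to (-pi, pi].
    dtheta = torch.atan2(torch.sin(pose_b[:, 2] - pose_a[:, 2]),
                         torch.cos(pose_b[:, 2] - pose_a[:, 2]))
    return torch.stack([x_rel, y_rel, dtheta], dim=1)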
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    ndevices = torch.cuda.device_count()

    # Setup loggers
    tbwriter = SummaryWriter(log_dir=args.log_dir)
    logging.basicConfig(filename=f"{args.log_dir}/train_log.txt",
                        level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.getLogger().setLevel(logging.INFO)

    if "habitat" in args.env_name:
        devices = [
            int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        ]
        # Devices need to be indexed between 0 and N-1.
        devices = [dev for dev in range(len(devices))]
        envs = make_vec_envs_habitat(args.habitat_config_file,
                                     device,
                                     devices,
                                     seed=args.seed)
    else:
        train_log_dir = os.path.join(args.log_dir, "train_monitor")
        try:
            os.makedirs(train_log_dir)
        except OSError:
            pass
        envs = make_vec_envs_avd(
            args.env_name,
            args.seed,
            args.num_processes,
            train_log_dir,
            device,
            True,
            num_frame_stack=1,
            split="train",
            nRef=args.num_pose_refs,
        )

    args.feat_shape_sim = (512, )
    args.obs_shape = envs.observation_space.spaces["im"].shape
    args.odometer_shape = (4, )  # (delta_y, delta_x, delta_head, delta_elev)

    # =================== Load clusters =================
    clusters_h5 = h5py.File(args.clusters_path, "r")
    cluster_centroids = torch.Tensor(
        np.array(clusters_h5["cluster_centroids"])).to(device)
    cluster_centroids_t = cluster_centroids.t()
    args.nclusters = cluster_centroids.shape[0]
    clusters2images = {}
    for i in range(args.nclusters):
        cluster_images = np.array(
            clusters_h5[f"cluster_{i}/images"])  # (K, C, H, W)
        cluster_images = rearrange(cluster_images, "k c h w -> k h w c")
        cluster_images = (cluster_images * 255.0).astype(np.uint8)
        clusters2images[i] = cluster_images  # (K, H, W, C)
    clusters_h5.close()

    # =================== Create models ====================
    decoder = FeatureReconstructionModule(
        args.nclusters,
        args.nclusters,
        nlayers=args.n_transformer_layers,
    )
    feature_network = FeatureNetwork()
    pose_encoder = PoseEncoder()
    encoder = RGBEncoder() if args.encoder_type == "rgb" else MapRGBEncoder()
    action_config = ({
        "nactions": envs.action_space.n,
        "embedding_size": args.action_embedding_size,
    } if args.use_action_embedding else None)
    collision_config = ({
        "collision_dim": 2,
        "embedding_size": args.collision_embedding_size,
    } if args.use_collision_embedding else None)
    actor_critic = Policy(
        envs.action_space,
        base_kwargs={
            "feat_dim": args.feat_shape_sim[0],
            "recurrent": True,
            "hidden_size": args.feat_shape_sim[0],
            "action_config": action_config,
            "collision_config": collision_config,
        },
    )

    # =================== Load models ====================
    decoder_state, pose_encoder_state = torch.load(args.load_path_rec)[:2]
    # Remove DataParallel-related prefixes from the checkpoint keys.
    new_decoder_state, new_pose_encoder_state = {}, {}
    for k, v in decoder_state.items():
        new_decoder_state[k.replace("module.", "")] = v
    for k, v in pose_encoder_state.items():
        new_pose_encoder_state[k.replace("module.", "")] = v
    decoder.load_state_dict(new_decoder_state)
    pose_encoder.load_state_dict(new_pose_encoder_state)
    decoder = nn.DataParallel(decoder, dim=1)
    pose_encoder = nn.DataParallel(pose_encoder, dim=0)

    save_path = os.path.join(args.save_dir, "checkpoints")
    checkpoint_path = os.path.join(save_path, "ckpt.latest.pth")
    if os.path.isfile(checkpoint_path):
        print("Resuming from old model!")
        loaded_states = torch.load(checkpoint_path)
        encoder_state, actor_critic_state, j_start = loaded_states
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
    elif args.pretrained_il_model != "":
logging.info("Initializing with pre-trained model!") encoder_state, actor_critic_state, _ = torch.load( args.pretrained_il_model) actor_critic.load_state_dict(actor_critic_state) encoder.load_state_dict(encoder_state) j_start = -1 else: j_start = -1 encoder.to(device) actor_critic.to(device) decoder.to(device) feature_network.to(device) pose_encoder.to(device) encoder.train() actor_critic.train() # decoder, feature_network, pose_encoder are frozen during policy training decoder.eval() feature_network.eval() pose_encoder.eval() # =================== Define RL training algorithm ==================== rl_algo_config = {} rl_algo_config["lr"] = args.lr rl_algo_config["eps"] = args.eps rl_algo_config["encoder_type"] = args.encoder_type rl_algo_config["max_grad_norm"] = args.max_grad_norm rl_algo_config["clip_param"] = args.clip_param rl_algo_config["ppo_epoch"] = args.ppo_epoch rl_algo_config["entropy_coef"] = args.entropy_coef rl_algo_config["num_mini_batch"] = args.num_mini_batch rl_algo_config["value_loss_coef"] = args.value_loss_coef rl_algo_config["use_clipped_value_loss"] = False rl_algo_config["nactions"] = envs.action_space.n rl_algo_config["encoder"] = encoder rl_algo_config["actor_critic"] = actor_critic rl_algo_config["use_action_embedding"] = args.use_action_embedding rl_algo_config["use_collision_embedding"] = args.use_collision_embedding rl_agent = PPO(rl_algo_config) # =================== Define rollouts ==================== rollouts_recon = RolloutStorageReconstruction( args.num_steps, args.num_processes, (args.nclusters, ), args.odometer_shape, args.num_pose_refs, ) rollouts_policy = RolloutStoragePPO( args.num_rl_steps, args.num_processes, args.obs_shape, envs.action_space, args.feat_shape_sim[0], encoder_type=args.encoder_type, ) rollouts_recon.to(device) rollouts_policy.to(device) def get_obs(obs): obs_im = process_image(obs["im"]) if args.encoder_type == "rgb+map": obs_lm = process_image(obs["coarse_occupancy"]) obs_sm = process_image(obs["fine_occupancy"]) else: obs_lm = None obs_sm = None return obs_im, obs_sm, obs_lm start = time.time() NPROC = args.num_processes NREF = args.num_pose_refs for j in range(j_start + 1, num_updates): # =================== Start a new episode ==================== obs = envs.reset() # Processing environment inputs obs_im, obs_sm, obs_lm = get_obs(obs) # (num_processes, 3, 84, 84) obs_odometer = process_odometer(obs["delta"]) # (num_processes, 4) # Convert mm to m for AVD if "avd" in args.env_name: obs_odometer[:, :2] /= 1000.0 obs_collns = obs["collisions"].long() # (num_processes, 1) # ============== Target poses and corresponding images ================ # NOTE - these are constant throughout the episode. 
        # (num_processes * num_pose_refs, 3) --- (y, x, t)
        tgt_poses = process_odometer(flatten_two(obs["pose_regress"]))[:, :3]
        tgt_poses = unflatten_two(tgt_poses, NPROC, NREF)  # (N, nRef, 3)
        tgt_masks = obs["valid_masks"].unsqueeze(2)  # (N, nRef, 1)
        # Convert mm to m for AVD.
        if "avd" in args.env_name:
            tgt_poses[:, :, :2] /= 1000.0
        tgt_ims = process_image(flatten_two(
            obs["pose_refs"]))  # (N*nRef, C, H, W)

        # Initialize the memory of rollouts for reconstruction.
        rollouts_recon.reset()
        with torch.no_grad():
            obs_feat = feature_network(obs_im)  # (N, 2048)
            tgt_feat = feature_network(tgt_ims)  # (N*nRef, 2048)
            # Compute similarity scores with all clusters.
            obs_feat = torch.matmul(obs_feat,
                                    cluster_centroids_t)  # (N, nclusters)
            tgt_feat = torch.matmul(tgt_feat,
                                    cluster_centroids_t)  # (N*nRef, nclusters)
            tgt_feat = unflatten_two(tgt_feat, NPROC,
                                     NREF)  # (N, nRef, nclusters)
        rollouts_recon.obs_feats[0].copy_(obs_feat)
        rollouts_recon.obs_odometer[0].copy_(obs_odometer)
        rollouts_recon.tgt_poses.copy_(tgt_poses)
        rollouts_recon.tgt_feats.copy_(tgt_feat)
        rollouts_recon.tgt_masks.copy_(tgt_masks)

        # Initialize the memory of rollouts for the policy.
        rollouts_policy.reset()
        rollouts_policy.obs_im[0].copy_(obs_im)
        if args.encoder_type == "rgb+map":
            rollouts_policy.obs_sm[0].copy_(obs_sm)
            rollouts_policy.obs_lm[0].copy_(obs_lm)
        rollouts_policy.collisions[0].copy_(obs_collns)

        # Episode statistics
        episode_expl_rewards = np.zeros((NPROC, 1))
        episode_collisions = np.zeros((NPROC, 1))
        episode_rec_rewards = np.zeros((NPROC, 1))
        episode_collisions += obs_collns.cpu().numpy()

        # Metrics
        osr_tracker = [0.0 for _ in range(NPROC)]
        objects_tracker = [0.0 for _ in range(NPROC)]
        area_tracker = [0.0 for _ in range(NPROC)]
        novelty_tracker = [0.0 for _ in range(NPROC)]
        smooth_coverage_tracker = [0.0 for _ in range(NPROC)]
        per_proc_area = [0.0 for _ in range(NPROC)]

        # Other states
        prev_action = torch.zeros(NPROC, 1).long().to(device)
        prev_collision = rollouts_policy.collisions[0]
        rec_reward_interval = args.rec_reward_interval
        prev_rec_rewards = torch.zeros(NPROC, 1)  # (N, 1)
        prev_rec_rewards = prev_rec_rewards.to(device)
        rec_rewards_at_t0 = None

        # ============ Update over a full batch of episodes ============
        # num_steps must be the total number of steps in each episode.
        for step in range(args.num_steps):
            pstep = rollouts_policy.step

            with torch.no_grad():
                encoder_inputs = [rollouts_policy.obs_im[pstep]]
                if args.encoder_type == "rgb+map":
                    encoder_inputs.append(rollouts_policy.obs_sm[pstep])
                    encoder_inputs.append(rollouts_policy.obs_lm[pstep])
                obs_feats = encoder(*encoder_inputs)
                policy_inputs = {"features": obs_feats}
                if args.use_action_embedding:
                    policy_inputs["actions"] = prev_action.long()
                if args.use_collision_embedding:
                    policy_inputs["collisions"] = prev_collision.long()

                policy_outputs = actor_critic.act(
                    policy_inputs,
                    rollouts_policy.recurrent_hidden_states[pstep],
                    rollouts_policy.masks[pstep],
                )
                (
                    value,
                    action,
                    action_log_probs,
                    recurrent_hidden_states,
                ) = policy_outputs

            # Act, get reward and next obs.
            obs, reward, done, infos = envs.step(action)

            # Process environment inputs.
            obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
            obs_odometer = process_odometer(obs["delta"])  # (num_processes, 4)
            if "avd" in args.env_name:
                obs_odometer[:, :2] /= 1000.0
            obs_collns = obs["collisions"]  # (N, 1)
            with torch.no_grad():
                obs_feat = feature_network(obs_im)
                # Compute similarity scores with all clusters.
                obs_feat = torch.matmul(obs_feat,
                                        cluster_centroids_t)  # (N, nclusters)

            # Always set masks to 1 (since this loop happens within one
            # episode).
            masks = torch.FloatTensor([[1.0] for _ in range(NPROC)]).to(device)

            # Accumulate odometer readings to obtain the pose relative to
            # the starting point.
            obs_odometer = rollouts_recon.obs_odometer[step] * masks + obs_odometer

            # Update rollouts_recon.
            rollouts_recon.insert(obs_feat, obs_odometer)

            # Compute the exploration rewards.
            reward_exploration = torch.zeros(NPROC, 1)  # (N, 1)
            for proc in range(NPROC):
                seen_area = float(infos[proc]["seen_area"])
                objects_visited = infos[proc].get("num_objects_visited", 0.0)
                oracle_success = float(infos[proc]["oracle_pose_success"])
                novelty_reward = infos[proc].get("count_based_reward", 0.0)
                smooth_coverage_reward = infos[proc].get(
                    "coverage_novelty_reward", 0.0)
                area_reward = seen_area - area_tracker[proc]
                objects_reward = objects_visited - objects_tracker[proc]
                landmarks_reward = oracle_success - osr_tracker[proc]
                collision_reward = -obs_collns[proc, 0].item()

                area_tracker[proc] = seen_area
                objects_tracker[proc] = objects_visited
                osr_tracker[proc] = oracle_success
                per_proc_area[proc] = seen_area
                novelty_tracker[proc] += novelty_reward
                smooth_coverage_tracker[proc] += smooth_coverage_reward

            # Compute reconstruction rewards.
            if (step + 1) % rec_reward_interval == 0 or step == 0:
                rec_rewards = compute_reconstruction_rewards(
                    rollouts_recon.obs_feats[:(step + 1)],
                    rollouts_recon.obs_odometer[:(step + 1), :, :3],
                    rollouts_recon.tgt_feats,
                    rollouts_recon.tgt_poses,
                    cluster_centroids_t,
                    decoder,
                    pose_encoder,
                ).detach()  # (N, nRef)
                rec_rewards = rec_rewards * tgt_masks.squeeze(2)  # (N, nRef)
                rec_rewards = rec_rewards.sum(dim=1).unsqueeze(
                    1)  # / (tgt_masks.sum(dim=1) + 1e-8)
                # The shaped reward is the *gain* in reconstruction reward
                # since the previous evaluation.
                final_rec_rewards = rec_rewards - prev_rec_rewards
                # Ignore the exploration reward at T=0 since it will be a
                # huge spike.
                if (("avd" in args.env_name) and (step != 0)) or (
                    ("habitat" in args.env_name) and (step > 20)):
                    reward_exploration += (final_rec_rewards.cpu() *
                                           args.rec_reward_scale)
                    episode_rec_rewards += final_rec_rewards.cpu().numpy()
                prev_rec_rewards = rec_rewards

            overall_reward = (reward * (1 - args.reward_scale) +
                              reward_exploration * args.reward_scale)

            # Update statistics.
            episode_expl_rewards += reward_exploration.numpy() * args.reward_scale

            # Update rollouts_policy.
            rollouts_policy.insert(
                obs_im,
                obs_sm,
                obs_lm,
                recurrent_hidden_states,
                action,
                action_log_probs,
                value,
                overall_reward,
                masks,
                obs_collns,
            )

            # Update previous values.
            prev_collision = obs_collns
            prev_action = action
            episode_collisions += obs_collns.cpu().numpy()

            # Update the RL policy.
            if (step + 1) % args.num_rl_steps == 0:
                # Update the value function for the last step.
                with torch.no_grad():
                    encoder_inputs = [rollouts_policy.obs_im[-1]]
                    if args.encoder_type == "rgb+map":
                        encoder_inputs.append(rollouts_policy.obs_sm[-1])
                        encoder_inputs.append(rollouts_policy.obs_lm[-1])
                    obs_feats = encoder(*encoder_inputs)
                    policy_inputs = {"features": obs_feats}
                    if args.use_action_embedding:
                        policy_inputs["actions"] = prev_action.long()
                    if args.use_collision_embedding:
                        policy_inputs["collisions"] = prev_collision.long()
                    next_value = actor_critic.get_value(
                        policy_inputs,
                        rollouts_policy.recurrent_hidden_states[-1],
                        rollouts_policy.masks[-1],
                    ).detach()

                # Compute returns.
                rollouts_policy.compute_returns(next_value, args.use_gae,
                                                args.gamma, args.tau)

                # Update the model.
                rl_losses = rl_agent.update(rollouts_policy)

                # Refresh rollouts_policy.
                rollouts_policy.after_update()

        # =================== Save model ====================
        if (j + 1) % args.save_interval == 0 and args.save_dir != "":
            save_path = f"{args.save_dir}/checkpoints"
            try:
                os.makedirs(save_path)
            except OSError:
                pass
            encoder_state = encoder.state_dict()
            actor_critic_state = actor_critic.state_dict()
            torch.save(
                [encoder_state, actor_critic_state, j],
                f"{save_path}/ckpt.latest.pth",
            )
            if args.save_unique:
                torch.save(
                    [encoder_state, actor_critic_state, j],
                    f"{save_path}/ckpt.{(j+1):07d}.pth",
                )

        # =================== Logging data ====================
        total_num_steps = (j + 1 - j_start) * NPROC * args.num_steps
        if j % args.log_interval == 0:
            end = time.time()
            fps = int(total_num_steps / (end - start))
            print(f"===> Updates {j}, #steps {total_num_steps}, FPS {fps}")
            train_metrics = rl_losses
            train_metrics["exploration_rewards"] = (
                np.mean(episode_expl_rewards) * rec_reward_interval /
                args.num_steps)
            train_metrics["rec_rewards"] = (np.mean(episode_rec_rewards) *
                                            rec_reward_interval /
                                            args.num_steps)
            train_metrics["area_covered"] = np.mean(per_proc_area)
            train_metrics["objects_covered"] = np.mean(objects_tracker)
            train_metrics["landmarks_covered"] = np.mean(osr_tracker)
            train_metrics["collisions"] = np.mean(episode_collisions)
            train_metrics["novelty_rewards"] = np.mean(novelty_tracker)
            train_metrics["smooth_coverage_rewards"] = np.mean(
                smooth_coverage_tracker)
            for k, v in train_metrics.items():
                print(f"{k}: {v:.3f}")
                tbwriter.add_scalar(f"train_metrics/{k}", v, j)

        # =================== Evaluate models ====================
        if args.eval_interval is not None and (j + 1) % args.eval_interval == 0:
            if "habitat" in args.env_name:
                devices = [
                    int(dev)
                    for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
                ]
                # Devices need to be indexed between 0 and N-1.
                devices = [dev for dev in range(len(devices))]
                eval_envs = make_vec_envs_habitat(
                    args.eval_habitat_config_file, device, devices)
            else:
                eval_envs = make_vec_envs_avd(
                    args.env_name,
                    args.seed + 12,
                    12,
                    eval_log_dir,
                    device,
                    True,
                    split="val",
                    nRef=NREF,
                    set_return_topdown_map=True,
                )

            num_eval_episodes = 16 if "habitat" in args.env_name else 30

            eval_config = {}
            eval_config["num_steps"] = args.num_steps
            eval_config["feat_shape_sim"] = args.feat_shape_sim
            eval_config["num_processes"] = (1 if "habitat" in args.env_name
                                            else 12)
            eval_config["odometer_shape"] = args.odometer_shape
            eval_config["num_eval_episodes"] = num_eval_episodes
            eval_config["num_pose_refs"] = NREF
            eval_config["env_name"] = args.env_name
            eval_config["actor_type"] = "learned_policy"
            eval_config["encoder_type"] = args.encoder_type
            eval_config["use_action_embedding"] = args.use_action_embedding
            eval_config["use_collision_embedding"] = args.use_collision_embedding
            eval_config["cluster_centroids"] = cluster_centroids
            eval_config["clusters2images"] = clusters2images
            eval_config["rec_loss_fn"] = rec_loss_fn_classify
            eval_config["vis_save_dir"] = (
                f"{args.save_dir}/policy_vis/update_{(j+1):05d}")

            models = {}
            models["decoder"] = decoder
            models["pose_encoder"] = pose_encoder
            models["feature_network"] = feature_network
            models["encoder"] = encoder
            models["actor_critic"] = actor_critic

            val_metrics, _ = evaluate_reconstruction(models, eval_envs,
                                                     eval_config, device)
            # Restore the train/eval modes after evaluation.
            decoder.eval()
            pose_encoder.eval()
            feature_network.eval()
            actor_critic.train()
            encoder.train()
            for k, v in val_metrics.items():
                tbwriter.add_scalar(f"val_metrics/{k}", v, j)

    tbwriter.close()
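# Assuming this script is the training entry point (and that `args` is parsed
# at module load, as the module-level references to `args` suggest), the
# standard guard would be:
if __name__ == "__main__":
    main()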
def compute_pose_sptm_ransac(
    obs_feats_sim,
    obs_feats_pose,
    obs_odometer,
    ref_feats_sim,
    ref_feats_pose,
    config,
    models,
    device,
    env_name,
):
    """
    Given a history of observations in the form of features and odometer
    readings, estimate the locations of a set of references given their
    features.

    Inputs:
        obs_feats_sim  - (T, N, feat_size_sim)
        obs_feats_pose - (T, N, feat_size_pose)
        obs_odometer   - (T, N, 4)
        ref_feats_sim  - (N, nRef, feat_size_sim)
        ref_feats_pose - (N, nRef, feat_size_pose)
    Outputs:
        predicted_poses      - (N, nRef, num_poses)
        predicted_positions  - (N, nRef, 3)
        novote_masks         - (N, nRef)
        all_paired_poses_map - (T, N, nRef, 3)
        obs_poses            - (T, N, 3)
        inlier_mask          - (T, N, nRef)
        all_pairwise_scores  - (T, N, nRef)
        voting_maps          - (N, nRef, 1, mh, mw)
    """
    map_shape = config["map_shape"]
    map_scale = config["map_scale"]
    bin_size = config["bin_size"]
    angles = config["angles"]
    median_filter_size = config["median_filter_size"]
    match_thresh = config["match_thresh"]
    rnet = models["rnet"]
    posenet = models["posenet"]  # Pairwise pose predictor
    pose_head = models["pose_head"]
    ransac_estimator = models["ransac_estimator"]

    # ========== Compute features for similarity prediction ===========
    T, N = obs_feats_sim.shape[:2]
    nRef = ref_feats_sim.shape[1]
    feat_size_sim = obs_feats_sim.shape[2]

    # ========== Compute the positions of each observation on the map ==========
    # obs_odometer ---> (T, N, 4) ---> (y, x, phi_head, phi_elev)
    obs_poses = torch.index_select(
        obs_odometer, 2,
        torch.LongTensor([1, 0, 2]).to(device))  # (T, N, 3) ---> (x, y, phi_head)

    # ========== Compute pairwise similarity with all prior observations ==========
    median_filter = MedianPool1d(median_filter_size, 1,
                                 median_filter_size // 2)
    median_filter.to(device)
    obs_feats_sim = repeat(obs_feats_sim, "t n f -> (t n r) f", r=nRef)
    ref_feats_sim = repeat(ref_feats_sim, "n r f -> (t n r) f", t=T)
    with torch.no_grad():
        paired_scores = rnet.compare(
            torch.cat([obs_feats_sim, ref_feats_sim], dim=1))
        paired_scores = F.softmax(paired_scores, dim=1)[:, 1]  # (T*N*nRef, )
    paired_scores = rearrange(paired_scores, "(t n r) -> t (n r)", t=T, n=N)
    if paired_scores.shape[0] > 1:
        # Apply median filtering along the time axis.
        paired_scores = rearrange(paired_scores, "t nr -> nr () t")
        paired_scores = median_filter(paired_scores)
        paired_scores = rearrange(paired_scores, "nr () t -> t nr")
    paired_scores = rearrange(paired_scores, "t (n r) -> t n r", t=T, n=N)

    # ========== Compute pairwise poses with all prior observations ==========
    ref_feats_pose = repeat(ref_feats_pose, "n r f -> (t n r) f", t=T)
    obs_feats_pose = repeat(obs_feats_pose, "t n f -> (t n r) f", r=nRef)
    with torch.no_grad():
        pairwise_dposes = posenet.get_pose_xyt_feats(obs_feats_pose,
                                                     ref_feats_pose)
        if "avd" in env_name:
            pairwise_dposes[:, :2] *= 1000.0  # (m -> mm)

    # ============= Add pairwise deltas to the observation poses ==============
    obs_poses_rep = repeat(obs_poses, "t n p -> (t n r) p", r=nRef)
    pairwise_poses_world = add_pose(obs_poses_rep, pairwise_dposes, mode="xyt")
    pairwise_poses_world = pairwise_poses_world.view(T, N, nRef, 3)  # (x, y, t)

    # ========== Define similarity-weighted sampling function ==========
    paired_scores = rearrange(paired_scores, "t n r -> t (n r)")
    # When no samples fall above match_thresh, lower the threshold for the
    # affected elements.
    match_thresh_mask = (paired_scores > match_thresh).sum(
        dim=0) == 0  # (N*nRef, )
    batch_match_thresh = torch.ones(N * nRef).to(device) * match_thresh
    # If any element has zero samples above the match threshold:
    if match_thresh_mask.sum().item() > 0:
        batch_match_thresh[match_thresh_mask] = (
            paired_scores[:, match_thresh_mask].max(dim=0)[0] - 0.001
        )  # (N*nRef, )
    batch_match_thresh = batch_match_thresh.unsqueeze(0)  # (1, N*nRef)

    # Compute a mask indicating the validity of samples along time.
    valid_masks = paired_scores > batch_match_thresh  # (T, N*nRef)
    # Assign zero weights to observations below the matching threshold.
    sample_weights = (paired_scores *
                      (paired_scores > batch_match_thresh).float())  # (T, N*nRef)
    pairwise_poses_world = rearrange(pairwise_poses_world,
                                     "t n r p -> t (n r) p")
    (
        pred_pose_inliers,
        pred_position_inliers,
        voting_map_inliers,
        inlier_mask,
    ) = ransac_estimator.ransac_pose_estimation(pairwise_poses_world,
                                                sample_weights, valid_masks)
    novote_masks = match_thresh_mask
    # pred_pose_inliers     - (N*nRef, num_poses)
    # pred_position_inliers - (N*nRef, 3)
    # novote_masks          - (N*nRef, )
    # voting_map_inliers    - (N*nRef, 1, mh, mw)
    pairwise_poses_map = ransac_estimator.polar2map(
        xyt2polar(flatten_two(pairwise_poses_world)))  # (T*N*nRef, 3)

    pred_pose_inliers = unflatten_two(pred_pose_inliers, N,
                                      nRef)  # (N, nRef, num_poses)
    pred_position_inliers = unflatten_two(pred_position_inliers, N,
                                          nRef)  # (N, nRef, 3)
    pairwise_poses_map = pairwise_poses_map.view(T, N, nRef, 3)  # (T, N, nRef, 3)
    novote_masks = novote_masks.view(N, nRef)  # (N, nRef)
    inlier_mask = inlier_mask.view(T, N, nRef)  # (T, N, nRef)
    pairwise_scores = paired_scores.view(T, N, nRef)  # (T, N, nRef)
    final_voting_maps = unflatten_two(voting_map_inliers, N,
                                      nRef)  # (N, nRef, 1, mh, mw)

    outputs = {}
    outputs["predicted_poses"] = pred_pose_inliers
    outputs["predicted_positions"] = pred_position_inliers
    outputs["novote_masks"] = novote_masks
    outputs["all_paired_poses_map"] = pairwise_poses_map
    outputs["obs_poses"] = obs_poses  # (T, N, 3)
    outputs["inlier_mask"] = inlier_mask
    outputs["all_pairwise_scores"] = pairwise_scores
    outputs["voting_maps"] = final_voting_maps
    # The observations which successfully voted.
    outputs["successful_votes"] = inlier_mask

    return outputs
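# NOTE: `MedianPool1d` is constructed with (kernel_size, stride, padding) but
# not defined in this file. A minimal sketch consistent with that call,
# assuming it computes a sliding-window median over (B, C, T) inputs:
class _MedianPool1dSketch(nn.Module):
    def __init__(self, kernel_size, stride=1, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        # x - (B, C, T); replicate-pad the time axis, then take the median
        # over each window. With padding = kernel_size // 2 and stride 1,
        # the output length matches the input for odd kernel sizes.
        x = F.pad(x, (self.padding, self.padding), mode="replicate")
        x = x.unfold(2, self.kernel_size, self.stride)  # (B, C, T', k)
        return x.median(dim=-1).values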
def update(self, rollouts):
    T, N, nfeats = rollouts.obs_feats[:-1].shape
    nRef = rollouts.tgt_feats.shape[1]
    device = rollouts.obs_feats.device
    avg_loss = 0.0
    avg_loss_count = 0.0

    tgt_feats = rollouts.tgt_feats  # (N, nRef, nfeats)
    tgt_masks = rollouts.tgt_masks.squeeze(2)  # (N, nRef)
    obs_feats = unsq_exp(rollouts.obs_feats, nRef,
                         dim=2)  # (T+1, N, nRef, nfeats)
    obs_poses = unsq_exp(rollouts.obs_odometer[:, :, :3], nRef,
                         dim=2)  # (T+1, N, nRef, 3) - (y, x, phi)
    tgt_poses = unsq_exp(rollouts.tgt_poses, T + 1, dim=0)  # (T+1, N, nRef, 3)

    # Make a prediction after every prediction_interval steps, i.e., once
    # the agent has seen the first self.prediction_interval observations,
    # then 2 * self.prediction_interval, and so on.
    for i in range(0, T, self.prediction_interval):
        L = min(i + self.prediction_interval, T)
        # Estimate the relative pose between targets and observations.
        obs_relpose = subtract_pose(
            rearrange(tgt_poses[:L], "l b n f -> (l b n) f"),
            rearrange(obs_poses[:L], "l b n f -> (l b n) f"),
        )  # (L*N*nRef, 3) --- (x, y, phi)

        # ========================= Forward pass ==========================
        # Encode the poses of the observations and targets.
        obs_relpose_enc = self.pose_encoder(obs_relpose)  # (L*N*nRef, 16)
        obs_relpose_enc = obs_relpose_enc.view(L, N * nRef, -1)
        tgt_relpose_enc = torch.zeros(1, *obs_relpose_enc.shape[1:]).to(device)
        obs_feats_i = rearrange(obs_feats[:L], "l b n f -> l (b n) f")
        # These serve as inputs to an encoder-decoder transformer model.
        rec_inputs = {
            # encoder inputs
            "history_image_features": obs_feats_i,  # (L, N*nRef, nfeats)
            "history_pose_features": obs_relpose_enc,  # (L, N*nRef, 16)
            # decoder inputs
            "target_pose_features": tgt_relpose_enc,  # (1, N*nRef, 16)
        }
        pred_logits = self.decoder(rec_inputs).squeeze(0)  # (N*nRef, nclass)

        # =================== Compute reconstruction loss =================
        loss = self.rec_loss_fn(
            pred_logits,  # (N*nRef, nclass)
            flatten_two(tgt_feats),  # (N*nRef, nfeats)
            self.cluster_centroids,
            K=self.rec_loss_fn_J,
            reduction="none",
        ).sum(dim=1)  # (N*nRef, )
        loss = unflatten_two(loss, N, nRef)
        # Mask out invalid targets.
        loss = loss * tgt_masks
        loss = loss.mean()

        # ========================== Backward pass ========================
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(
            itertools.chain(
                self.decoder.parameters(),
                self.pose_encoder.parameters(),
            ),
            self.max_grad_norm,
        )
        self.optimizer.step()

        avg_loss += loss.item()
        avg_loss_count += 1.0

    avg_loss = avg_loss / avg_loss_count
    losses = {"rec_loss": avg_loss}
    return losses
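# NOTE: `unsq_exp` is an imported helper. A minimal sketch consistent with
# its usage above -- insert a broadcast dimension of size k at `dim`:
def _unsq_exp_sketch(x, k, dim):
    """e.g. x - (T, N, f), k=nRef, dim=2 -> (T, N, nRef, f) expanded view."""
    sizes = [-1] * (x.dim() + 1)
    sizes[dim] = k
    return x.unsqueeze(dim).expand(*sizes)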