def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    ndevices = torch.cuda.device_count()
    # Setup loggers
    tbwriter = SummaryWriter(log_dir=args.log_dir)
    logging.basicConfig(filename=f"{args.log_dir}/train_log.txt", level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.getLogger().setLevel(logging.INFO)

    if "habitat" in args.env_name:
        devices = [int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
        # Devices need to be indexed between 0 and N-1
        devices = [dev for dev in range(len(devices))]
        if len(devices) > 2:
            devices = devices[1:]
        envs = make_vec_envs_habitat(
            args.habitat_config_file, device, devices, seed=args.seed
        )
    else:
        train_log_dir = os.path.join(args.log_dir, "train_monitor")
        try:
            os.makedirs(train_log_dir)
        except OSError:
            pass
        envs = make_vec_envs_avd(
            args.env_name,
            args.seed,
            args.num_processes,
            train_log_dir,
            device,
            True,
            num_frame_stack=1,
            split="train",
            nRef=args.num_pose_refs,
        )

    args.feat_shape_sim = (512,)
    args.feat_shape_pose = (512 * 9,)
    args.obs_shape = envs.observation_space.spaces["im"].shape

    # =================== Create models ====================
    if args.encoder_type == "rgb":
        encoder = RGBEncoder(fix_cnn=args.fix_cnn)
    elif args.encoder_type == "rgb+map":
        encoder = MapRGBEncoder(fix_cnn=args.fix_cnn)
    else:
        raise ValueError(f"encoder_type {args.encoder_type} not defined!")
    action_config = (
        {"nactions": envs.action_space.n, "embedding_size": args.action_embedding_size}
        if args.use_action_embedding
        else None
    )
    collision_config = (
        {"collision_dim": 2, "embedding_size": args.collision_embedding_size}
        if args.use_collision_embedding
        else None
    )
    actor_critic = Policy(
        envs.action_space,
        base_kwargs={
            "feat_dim": args.feat_shape_sim[0],
            "recurrent": True,
            "hidden_size": args.feat_shape_sim[0],
            "action_config": action_config,
            "collision_config": collision_config,
        },
    )
    icm_phi = Phi() if args.icm_embedding_type == "imagenet" else None
    icm_fd = ForwardDynamics(envs.action_space.n)

    # =================== Load models ====================
    save_path = os.path.join(args.save_dir, "checkpoints")
    checkpoint_path = os.path.join(save_path, "ckpt.latest.pth")
    if os.path.isfile(checkpoint_path):
        logging.info("Resuming from old model!")
        loaded_states = torch.load(checkpoint_path)
        encoder_state, actor_critic_state, icm_fd_state, j_start = loaded_states
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
        icm_fd.load_state_dict(icm_fd_state)
    elif args.pretrained_il_model != "":
        logging.info("Initializing with pre-trained model!")
        encoder_state, actor_critic_state, _ = torch.load(args.pretrained_il_model)
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
        j_start = -1
    else:
        j_start = -1

    encoder.to(device)
    actor_critic.to(device)
    if args.icm_embedding_type == "imagenet":
        icm_phi.to(device)
    icm_fd.to(device)
    encoder.eval()
    actor_critic.eval()
    if args.icm_embedding_type == "imagenet":
        icm_phi.eval()  # Do not train the feature model for ICM
    icm_fd.eval()

    # =================== Define ICM training algorithm ====================
    icm_optimizer = optim.Adam(icm_fd.parameters(), lr=args.lr)
    # Maintain a running mean of the variance of returns after every num-rl-steps
    if args.normalize_icm_rewards:
        args.returns_rms = RunningMeanStd()

    # =================== Define RL training algorithm ====================
    rl_algo_config = {}
    rl_algo_config["lr"] = args.lr
    rl_algo_config["eps"] = args.eps
    rl_algo_config["encoder_type"] = args.encoder_type
    rl_algo_config["max_grad_norm"] = args.max_grad_norm
    rl_algo_config["clip_param"] = args.clip_param
    rl_algo_config["ppo_epoch"] = args.ppo_epoch
    rl_algo_config["entropy_coef"] = args.entropy_coef
    rl_algo_config["num_mini_batch"] = args.num_mini_batch
    rl_algo_config["value_loss_coef"] = args.value_loss_coef
    rl_algo_config["use_clipped_value_loss"] = False
    rl_algo_config["nactions"] = envs.action_space.n
    rl_algo_config["encoder"] = encoder
    rl_algo_config["actor_critic"] = actor_critic
    rl_algo_config["use_action_embedding"] = args.use_action_embedding
    rl_algo_config["use_collision_embedding"] = args.use_collision_embedding
    rl_agent = PPO(rl_algo_config)

    # =================== Define stats buffer ====================
    train_metrics_tracker = defaultdict(lambda: deque(maxlen=10))

    # =================== Define rollouts ====================
    rollouts_policy = RolloutStoragePPO(
        args.num_rl_steps,
        args.num_processes,
        args.obs_shape,
        envs.action_space,
        args.feat_shape_sim[0],
        encoder_type=args.encoder_type,
    )
    rollouts_policy.to(device)

    def get_obs(obs):
        obs_im = process_image(obs["im"])
        if args.encoder_type == "rgb+map":
            obs_lm = process_image(obs["coarse_occupancy"])
            obs_sm = process_image(obs["fine_occupancy"])
        else:
            obs_lm = None
            obs_sm = None
        return obs_im, obs_sm, obs_lm

    start = time.time()
    NPROC = args.num_processes

    for j in range(j_start + 1, num_updates):
        # =================== Start a new episode ====================
        obs = envs.reset()
        # Reset ICM data buffer
        all_icm_feats = []
        all_icm_acts = []
        # Set ICM models to evaluation mode for data gathering
        if args.icm_embedding_type == "imagenet":
            icm_phi.eval()
        icm_fd.eval()
        # Processing environment inputs
        obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
        obs_collns = obs["collisions"].long()  # (num_processes, 1)
        # Initialize the memory of rollouts for policy
        rollouts_policy.reset()
        rollouts_policy.obs_im[0].copy_(obs_im)
        if args.encoder_type == "rgb+map":
            rollouts_policy.obs_sm[0].copy_(obs_sm)
            rollouts_policy.obs_lm[0].copy_(obs_lm)
        rollouts_policy.collisions[0].copy_(obs_collns)
        # Episode statistics
        episode_expl_rewards = np.zeros((NPROC, 1))
        episode_collisions = np.zeros((NPROC, 1))
        episode_collisions += obs_collns.cpu().numpy()
        # Metrics
        osr_tracker = [0.0 for _ in range(NPROC)]
        objects_tracker = [0.0 for _ in range(NPROC)]
        area_tracker = [0.0 for _ in range(NPROC)]
        novelty_tracker = [0.0 for _ in range(NPROC)]
        smooth_coverage_tracker = [0.0 for _ in range(NPROC)]
        per_proc_area = [0.0 for _ in range(NPROC)]
        # Other states
        prev_action = torch.zeros(NPROC, 1).to(device)
        prev_collision = obs_collns
        action_onehot = torch.zeros(NPROC, envs.action_space.n).to(device)  # (N, n_actions)

        # ================= Update over a full batch of episodes =================
        # num_steps must be total number of steps in each episode
        for step in range(args.num_steps):
            pstep = rollouts_policy.step
            with torch.no_grad():
                encoder_inputs = [rollouts_policy.obs_im[pstep]]
                if args.encoder_type == "rgb+map":
                    encoder_inputs.append(rollouts_policy.obs_sm[pstep])
                    encoder_inputs.append(rollouts_policy.obs_lm[pstep])
                obs_feats = encoder(*encoder_inputs)
                policy_inputs = {"features": obs_feats}
                if args.use_action_embedding:
                    policy_inputs["actions"] = prev_action.long()
                if args.use_collision_embedding:
                    policy_inputs["collisions"] = prev_collision.long()
                policy_outputs = actor_critic.act(
                    policy_inputs,
                    rollouts_policy.recurrent_hidden_states[pstep],
                    rollouts_policy.masks[pstep],
                )
                (
                    value,
                    action,
                    action_log_probs,
                    recurrent_hidden_states,
                ) = policy_outputs

            # Gather curiosity experience. By default, the features are detached
            # from the forward dynamics loss.
            if args.icm_embedding_type == "imagenet":
                with torch.no_grad():
                    icm_feats = icm_phi(obs_im)
            else:
                icm_feats = recurrent_hidden_states
            all_icm_feats.append(icm_feats)
            all_icm_acts.append(action)

            # Act, get reward and next obs
            obs, reward, done, infos = envs.step(action)

            # Processing environment inputs
            obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
            obs_collns = obs["collisions"]  # (num_processes, 1)
            # Always set masks to 1 (since this loop happens within one episode)
            masks = torch.FloatTensor([[1.0] for _ in range(NPROC)]).to(device)

            # Compute curiosity rewards for the previous action (not the current)
            reward_exploration = torch.zeros(NPROC, 1)
            if step >= 1:
                phi_st = all_icm_feats[-2]
                phi_st1 = all_icm_feats[-1]
                action_onehot.zero_()
                act = all_icm_acts[-2]
                action_onehot.scatter_(1, act, 1)
                with torch.no_grad():
                    phi_st1_hat = icm_fd(phi_st, action_onehot)
                reward_exploration = (
                    F.mse_loss(phi_st1_hat, phi_st1, reduction="none")
                    .sum(dim=1)
                    .unsqueeze(1)
                    .detach()
                )  # (N, 1)
                # Since this reward corresponds to the previous action, update it
                # accordingly in the rollouts buffer.
                rollouts_policy.update_prev_rewards(reward_exploration * args.reward_scale)
            reward_exploration = reward_exploration.cpu()

            for proc in range(NPROC):
                seen_area = float(infos[proc]["seen_area"])
                objects_visited = infos[proc].get("num_objects_visited", 0.0)
                oracle_success = float(infos[proc].get("oracle_pose_success", 0.0))
                novelty_reward = infos[proc].get("count_based_reward", 0.0)
                smooth_coverage_reward = infos[proc].get("coverage_novelty_reward", 0.0)
                area_tracker[proc] = seen_area
                objects_tracker[proc] = objects_visited
                osr_tracker[proc] = oracle_success
                per_proc_area[proc] = seen_area
                novelty_tracker[proc] += novelty_reward
                smooth_coverage_tracker[proc] += smooth_coverage_reward

            # Intrinsic reward is updated separately (delayed by 1 time step)
            overall_reward = reward * (1 - args.reward_scale)

            # Update statistics
            episode_expl_rewards += reward_exploration.numpy() * args.reward_scale

            # Update rollouts_policy
            rollouts_policy.insert(
                obs_im,
                obs_sm,
                obs_lm,
                recurrent_hidden_states,
                action,
                action_log_probs,
                value,
                overall_reward,
                masks,
                obs_collns,
            )

            # Update prev values
            prev_collision = obs_collns
            prev_action = action
            episode_collisions += obs_collns.cpu().numpy()

            # Update RL policy
            if (step + 1) % args.num_rl_steps == 0:
                # Update value function for last step
                with torch.no_grad():
                    encoder_inputs = [rollouts_policy.obs_im[-1]]
                    if args.encoder_type == "rgb+map":
                        encoder_inputs.append(rollouts_policy.obs_sm[-1])
                        encoder_inputs.append(rollouts_policy.obs_lm[-1])
                    obs_feats = encoder(*encoder_inputs)
                    policy_inputs = {"features": obs_feats}
                    if args.use_action_embedding:
                        policy_inputs["actions"] = prev_action.long()
                    if args.use_collision_embedding:
                        policy_inputs["collisions"] = prev_collision.long()
                    next_value = actor_critic.get_value(
                        policy_inputs,
                        rollouts_policy.recurrent_hidden_states[-1],
                        rollouts_policy.masks[-1],
                    ).detach()
                # Normalize the rewards if applicable
                if args.normalize_icm_rewards:
                    current_returns = 0.0
                    for rew in torch.flip(rollouts_policy.rewards, dims=[0]):
                        current_returns = current_returns * args.gamma + rew
                    current_returns = current_returns.squeeze(1).cpu().numpy()
                    args.returns_rms.update(current_returns)
                    rollouts_policy.rewards /= args.returns_rms.var.item()
                # Compute returns
                rollouts_policy.compute_returns(
                    next_value,
                    args.use_gae,
                    args.gamma,
                    args.tau,
                )
                encoder.train()
                actor_critic.train()
                # Update model
                rl_losses = rl_agent.update(rollouts_policy)
                # Refresh rollouts
                rollouts_policy.after_update()
                encoder.eval()
                actor_critic.eval()

        # ============ Update the ICM dynamics model using past data ===============
        icm_fd.train()
        action_onehot = torch.zeros(NPROC, envs.action_space.n).to(device)  # (N, n_actions)
        avg_fd_loss = 0
        avg_fd_loss_count = 0
        icm_update_count = 0
        for t in random_range(0, args.num_steps - 1):
            phi_st = all_icm_feats[t]  # (N, 512)
            phi_st1 = all_icm_feats[t + 1]  # (N, 512)
            action_onehot.zero_()
            at = all_icm_acts[t].long()  # (N, 1)
            action_onehot.scatter_(1, at, 1)
            # Forward pass
            phi_st1_hat = icm_fd(phi_st, action_onehot)
            fd_loss = F.mse_loss(phi_st1_hat, phi_st1)
            # Backward pass
            icm_optimizer.zero_grad()
            fd_loss.backward()
            torch.nn.utils.clip_grad_norm_(icm_fd.parameters(), args.max_grad_norm)
            # Update step
            icm_optimizer.step()
            avg_fd_loss += fd_loss.item()
            avg_fd_loss_count += phi_st1_hat.shape[0]
        avg_fd_loss /= avg_fd_loss_count
        all_losses = {"icm_fd_loss": avg_fd_loss}
        icm_fd.eval()

        # =================== Save model ====================
        if (j + 1) % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, "checkpoints")
            try:
                os.makedirs(save_path)
            except OSError:
                pass
            encoder_state = encoder.state_dict()
            actor_critic_state = actor_critic.state_dict()
            icm_fd_state = icm_fd.state_dict()
            torch.save(
                [encoder_state, actor_critic_state, icm_fd_state, j],
                f"{save_path}/ckpt.latest.pth",
            )
            if args.save_unique:
                torch.save(
                    [encoder_state, actor_critic_state, icm_fd_state, j],
                    f"{save_path}/ckpt.{(j+1):07d}.pth",
                )

        # =================== Logging data ====================
        total_num_steps = (j + 1 - j_start) * NPROC * args.num_steps
        if j % args.log_interval == 0:
            end = time.time()
            fps = int(total_num_steps / (end - start))
            logging.info(f"===> Updates {j}, #steps {total_num_steps}, FPS {fps}")
            train_metrics = rl_losses
            train_metrics.update(all_losses)
            train_metrics["exploration_rewards"] = np.mean(episode_expl_rewards)
            train_metrics["area_covered"] = np.mean(per_proc_area)
            train_metrics["objects_covered"] = np.mean(objects_tracker)
            train_metrics["landmarks_covered"] = np.mean(osr_tracker)
            train_metrics["collisions"] = np.mean(episode_collisions)
            train_metrics["novelty_rewards"] = np.mean(novelty_tracker)
            train_metrics["smooth_coverage_rewards"] = np.mean(smooth_coverage_tracker)
            # Update statistics
            for k, v in train_metrics.items():
                train_metrics_tracker[k].append(v)
            for k, v in train_metrics_tracker.items():
                logging.info(f"{k}: {np.mean(v).item():.3f}")
                tbwriter.add_scalar(f"train_metrics/{k}", np.mean(v).item(), j)

        # =================== Evaluate models ====================
        if args.eval_interval is not None and (j + 1) % args.eval_interval == 0:
            if "habitat" in args.env_name:
                devices = [
                    int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
                ]
                # Devices need to be indexed between 0 and N-1
                devices = [dev for dev in range(len(devices))]
                eval_envs = make_vec_envs_habitat(
                    args.eval_habitat_config_file, device, devices
                )
            else:
                eval_envs = make_vec_envs_avd(
                    args.env_name,
                    args.seed + 12,
                    12,
                    eval_log_dir,
                    device,
                    True,
                    split="val",
                    nRef=args.num_pose_refs,
                    set_return_topdown_map=True,
                )
            num_eval_episodes = 16 if "habitat" in args.env_name else 30
            eval_config = {}
            eval_config["num_steps"] = args.num_steps
            eval_config["feat_shape_sim"] = args.feat_shape_sim
            eval_config["num_processes"] = 1 if "habitat" in args.env_name else 12
            eval_config["num_pose_refs"] = args.num_pose_refs
            eval_config["num_eval_episodes"] = num_eval_episodes
            eval_config["env_name"] = args.env_name
            eval_config["actor_type"] = "learned_policy"
            eval_config["encoder_type"] = args.encoder_type
            eval_config["use_action_embedding"] = args.use_action_embedding
            eval_config["use_collision_embedding"] = args.use_collision_embedding
            eval_config["vis_save_dir"] = f"{args.save_dir}/policy_vis/update_{(j+1):05d}"
            models = {}
            models["encoder"] = encoder
            models["actor_critic"] = actor_critic
            val_metrics, _ = evaluate_visitation(
                models, eval_envs, eval_config, device, visualize_policy=False
            )
            for k, v in val_metrics.items():
                tbwriter.add_scalar(f"val_metrics/{k}", v, j)

    tbwriter.close()
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    ndevices = torch.cuda.device_count()
    # Setup loggers
    tbwriter = SummaryWriter(log_dir=args.log_dir)
    logging.basicConfig(filename=f"{args.log_dir}/train_log.txt", level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.getLogger().setLevel(logging.INFO)

    if "habitat" in args.env_name:
        devices = [int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
        # Devices need to be indexed between 0 and N-1
        devices = [dev for dev in range(len(devices))]
        envs = make_vec_envs_habitat(
            args.habitat_config_file, device, devices, seed=args.seed
        )
    else:
        train_log_dir = os.path.join(args.log_dir, "train_monitor")
        try:
            os.makedirs(train_log_dir)
        except OSError:
            pass
        envs = make_vec_envs_avd(
            args.env_name,
            args.seed,
            args.num_processes,
            train_log_dir,
            device,
            True,
            num_frame_stack=1,
            split="train",
            nRef=args.num_pose_refs,
        )

    args.feat_shape_sim = (512,)
    args.feat_shape_pose = (512 * 9,)
    args.obs_shape = envs.observation_space.spaces["im"].shape

    # =================== Create models ====================
    encoder = RGBEncoder() if args.encoder_type == "rgb" else MapRGBEncoder()
    action_config = (
        {"nactions": envs.action_space.n, "embedding_size": args.action_embedding_size}
        if args.use_action_embedding
        else None
    )
    collision_config = (
        {"collision_dim": 2, "embedding_size": args.collision_embedding_size}
        if args.use_collision_embedding
        else None
    )
    actor_critic = Policy(
        envs.action_space,
        base_kwargs={
            "feat_dim": args.feat_shape_sim[0],
            "recurrent": True,
            "hidden_size": args.feat_shape_sim[0],
            "action_config": action_config,
            "collision_config": collision_config,
        },
    )

    # =================== Load models ====================
    save_path = os.path.join(args.save_dir, "checkpoints")
    checkpoint_path = os.path.join(save_path, "ckpt.latest.pth")
    if os.path.isfile(checkpoint_path):
        logging.info("Resuming from old model!")
        loaded_states = torch.load(checkpoint_path)
        encoder_state, actor_critic_state, j_start = loaded_states
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
    elif args.pretrained_il_model != "":
        logging.info("Initializing with pre-trained model!")
        encoder_state, actor_critic_state, _ = torch.load(args.pretrained_il_model)
        actor_critic.load_state_dict(actor_critic_state)
        encoder.load_state_dict(encoder_state)
        j_start = -1
    else:
        j_start = -1

    encoder.to(device)
    actor_critic.to(device)
    encoder.train()
    actor_critic.train()

    # =================== Define RL training algorithm ====================
    rl_algo_config = {}
    rl_algo_config["lr"] = args.lr
    rl_algo_config["eps"] = args.eps
    rl_algo_config["encoder_type"] = args.encoder_type
    rl_algo_config["max_grad_norm"] = args.max_grad_norm
    rl_algo_config["clip_param"] = args.clip_param
    rl_algo_config["ppo_epoch"] = args.ppo_epoch
    rl_algo_config["entropy_coef"] = args.entropy_coef
    rl_algo_config["num_mini_batch"] = args.num_mini_batch
    rl_algo_config["value_loss_coef"] = args.value_loss_coef
    rl_algo_config["use_clipped_value_loss"] = False
    rl_algo_config["nactions"] = envs.action_space.n
    rl_algo_config["encoder"] = encoder
    rl_algo_config["actor_critic"] = actor_critic
    rl_algo_config["use_action_embedding"] = args.use_action_embedding
    rl_algo_config["use_collision_embedding"] = args.use_collision_embedding
    rl_agent = PPO(rl_algo_config)

    # =================== Define rollouts ====================
    rollouts_policy = RolloutStoragePPO(
        args.num_rl_steps,
        args.num_processes,
        args.obs_shape,
        envs.action_space,
        args.feat_shape_sim[0],
        encoder_type=args.encoder_type,
    )
    rollouts_policy.to(device)

    def get_obs(obs):
        obs_im = process_image(obs["im"])
        if args.encoder_type == "rgb+map":
            obs_lm = process_image(obs["coarse_occupancy"])
            obs_sm = process_image(obs["fine_occupancy"])
        else:
            obs_lm = None
            obs_sm = None
        return obs_im, obs_sm, obs_lm

    start = time.time()
    NPROC = args.num_processes

    for j in range(j_start + 1, num_updates):
        # =================== Start a new episode ====================
        obs = envs.reset()
        # Processing environment inputs
        obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
        obs_collns = obs["collisions"].long()  # (num_processes, 1)
        # Initialize the memory of rollouts for policy
        rollouts_policy.reset()
        rollouts_policy.obs_im[0].copy_(obs_im)
        if args.encoder_type == "rgb+map":
            rollouts_policy.obs_sm[0].copy_(obs_sm)
            rollouts_policy.obs_lm[0].copy_(obs_lm)
        rollouts_policy.collisions[0].copy_(obs_collns)
        # Episode statistics
        episode_expl_rewards = np.zeros((NPROC, 1))
        episode_collisions = np.zeros((NPROC, 1))
        episode_collisions += obs_collns.cpu().numpy()
        # Metrics
        osr_tracker = [0.0 for _ in range(NPROC)]
        objects_tracker = [0.0 for _ in range(NPROC)]
        area_tracker = [0.0 for _ in range(NPROC)]
        novelty_tracker = [0.0 for _ in range(NPROC)]
        smooth_coverage_tracker = [0.0 for _ in range(NPROC)]
        per_proc_area = [0.0 for _ in range(NPROC)]
        # Other states
        prev_action = torch.zeros(NPROC, 1).long().to(device)
        prev_collision = rollouts_policy.collisions[0]

        # ================= Update over a full batch of episodes =================
        # num_steps must be total number of steps in each episode
        for step in range(args.num_steps):
            pstep = rollouts_policy.step
            with torch.no_grad():
                encoder_inputs = [rollouts_policy.obs_im[pstep]]
                if args.encoder_type == "rgb+map":
                    encoder_inputs.append(rollouts_policy.obs_sm[pstep])
                    encoder_inputs.append(rollouts_policy.obs_lm[pstep])
                obs_feats = encoder(*encoder_inputs)
                policy_inputs = {"features": obs_feats}
                if args.use_action_embedding:
                    policy_inputs["actions"] = prev_action.long()
                if args.use_collision_embedding:
                    policy_inputs["collisions"] = prev_collision.long()
                policy_outputs = actor_critic.act(
                    policy_inputs,
                    rollouts_policy.recurrent_hidden_states[pstep],
                    rollouts_policy.masks[pstep],
                )
                (
                    value,
                    action,
                    action_log_probs,
                    recurrent_hidden_states,
                ) = policy_outputs

            # Act, get reward and next obs
            obs, reward, done, infos = envs.step(action)

            # Processing environment inputs
            obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
            obs_collns = obs["collisions"]  # (N, 1)
            # Always set masks to 1 (since this loop happens within one episode)
            masks = torch.FloatTensor([[1.0] for _ in range(NPROC)]).to(device)

            # Compute the exploration rewards
            reward_exploration = torch.zeros(NPROC, 1)  # (N, 1)
            for proc in range(NPROC):
                seen_area = float(infos[proc]["seen_area"])
                objects_visited = infos[proc].get("num_objects_visited", 0.0)
                oracle_success = float(infos[proc].get("oracle_pose_success", 0.0))
                novelty_reward = infos[proc].get("count_based_reward", 0.0)
                smooth_coverage_reward = infos[proc].get("coverage_novelty_reward", 0.0)
                area_reward = seen_area - area_tracker[proc]
                objects_reward = objects_visited - objects_tracker[proc]
                landmarks_reward = oracle_success - osr_tracker[proc]
                collision_reward = -obs_collns[proc, 0].item()
                reward_exploration[proc] += (
                    args.area_reward_scale * area_reward
                    + args.objects_reward_scale * objects_reward
                    + args.landmarks_reward_scale * landmarks_reward
                    + args.novelty_reward_scale * novelty_reward
                    + args.collision_penalty_factor * collision_reward
                    + args.smooth_coverage_reward_scale * smooth_coverage_reward
                )
                area_tracker[proc] = seen_area
                objects_tracker[proc] = objects_visited
                osr_tracker[proc] = oracle_success
                per_proc_area[proc] = seen_area
                novelty_tracker[proc] += novelty_reward
                smooth_coverage_tracker[proc] += smooth_coverage_reward

            overall_reward = (
                reward * (1 - args.reward_scale)
                + reward_exploration * args.reward_scale
            )

            # Update statistics
            episode_expl_rewards += reward_exploration.numpy() * args.reward_scale

            # Update rollouts_policy
            rollouts_policy.insert(
                obs_im,
                obs_sm,
                obs_lm,
                recurrent_hidden_states,
                action,
                action_log_probs,
                value,
                overall_reward,
                masks,
                obs_collns,
            )

            # Update prev values
            prev_collision = obs_collns
            prev_action = action
            episode_collisions += obs_collns.cpu().numpy()

            # Update RL policy
            if (step + 1) % args.num_rl_steps == 0:
                # Update value function for last step
                with torch.no_grad():
                    encoder_inputs = [rollouts_policy.obs_im[-1]]
                    if args.encoder_type == "rgb+map":
                        encoder_inputs.append(rollouts_policy.obs_sm[-1])
                        encoder_inputs.append(rollouts_policy.obs_lm[-1])
                    obs_feats = encoder(*encoder_inputs)
                    policy_inputs = {"features": obs_feats}
                    if args.use_action_embedding:
                        policy_inputs["actions"] = prev_action.long()
                    if args.use_collision_embedding:
                        policy_inputs["collisions"] = prev_collision.long()
                    next_value = actor_critic.get_value(
                        policy_inputs,
                        rollouts_policy.recurrent_hidden_states[-1],
                        rollouts_policy.masks[-1],
                    ).detach()
                # Compute returns
                rollouts_policy.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
                # Update model
                rl_losses = rl_agent.update(rollouts_policy)
                # Refresh rollouts
                rollouts_policy.after_update()

        # =================== Save model ====================
        if (j + 1) % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, "checkpoints")
            try:
                os.makedirs(save_path)
            except OSError:
                pass
            encoder_state = encoder.state_dict()
            actor_critic_state = actor_critic.state_dict()
            torch.save(
                [encoder_state, actor_critic_state, j],
                os.path.join(save_path, "ckpt.latest.pth"),
            )
            if args.save_unique:
                torch.save(
                    [encoder_state, actor_critic_state, j],
                    f"{save_path}/ckpt.{(j+1):07d}.pth",
                )

        # =================== Logging data ====================
        total_num_steps = (j + 1 - j_start) * NPROC * args.num_steps
        if j % args.log_interval == 0:
            end = time.time()
            fps = int(total_num_steps / (end - start))
            logging.info(f"===> Updates {j}, #steps {total_num_steps}, FPS {fps}")
            train_metrics = rl_losses
            train_metrics["exploration_rewards"] = np.mean(episode_expl_rewards)
            train_metrics["area_covered"] = np.mean(per_proc_area)
            train_metrics["objects_covered"] = np.mean(objects_tracker)
            train_metrics["landmarks_covered"] = np.mean(osr_tracker)
            train_metrics["collisions"] = np.mean(episode_collisions)
            train_metrics["novelty_rewards"] = np.mean(novelty_tracker)
            train_metrics["smooth_coverage_rewards"] = np.mean(smooth_coverage_tracker)
            for k, v in train_metrics.items():
                logging.info(f"{k}: {v:.3f}")
                tbwriter.add_scalar(f"train_metrics/{k}", v, j)

        # =================== Evaluate models ====================
        if args.eval_interval is not None and (j + 1) % args.eval_interval == 0:
            if "habitat" in args.env_name:
                devices = [
                    int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
                ]
                # Devices need to be indexed between 0 and N-1
                devices = [dev for dev in range(len(devices))]
                eval_envs = make_vec_envs_habitat(
                    args.eval_habitat_config_file, device, devices
                )
            else:
                eval_envs = make_vec_envs_avd(
                    args.env_name,
                    args.seed + 12,
                    12,
                    eval_log_dir,
                    device,
                    True,
                    split="val",
                    nRef=args.num_pose_refs,
                    set_return_topdown_map=True,
                )
            num_eval_episodes = 16 if "habitat" in args.env_name else 30
            eval_config = {}
            eval_config["num_steps"] = args.num_steps
            eval_config["feat_shape_sim"] = args.feat_shape_sim
            eval_config["num_processes"] = 1 if "habitat" in args.env_name else 12
            eval_config["num_pose_refs"] = args.num_pose_refs
            eval_config["num_eval_episodes"] = num_eval_episodes
            eval_config["env_name"] = args.env_name
            eval_config["actor_type"] = "learned_policy"
            eval_config["encoder_type"] = args.encoder_type
            eval_config["use_action_embedding"] = args.use_action_embedding
            eval_config["use_collision_embedding"] = args.use_collision_embedding
            eval_config["vis_save_dir"] = f"{args.save_dir}/policy_vis/update_{(j+1):05d}"
            models = {}
            models["encoder"] = encoder
            models["actor_critic"] = actor_critic
            val_metrics, _ = evaluate_visitation(
                models, eval_envs, eval_config, device, visualize_policy=False
            )
            for k, v in val_metrics.items():
                tbwriter.add_scalar(f"val_metrics/{k}", v, j)

    tbwriter.close()
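# =================== GAE returns: illustrative sketch ====================
# The call rollouts_policy.compute_returns(next_value, use_gae, gamma, tau)
# used by both training scripts above computes discounted returns, typically
# with Generalized Advantage Estimation. This is a minimal sketch of the
# standard GAE recursion under that assumption; the actual RolloutStoragePPO
# implementation may differ in details.
import torch


def compute_gae_returns(rewards, values, masks, next_value, gamma=0.99, tau=0.95):
    # rewards, values: (T, N, 1); masks: (T + 1, N, 1); next_value: (N, 1)
    T = rewards.shape[0]
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)  # (T + 1, N, 1)
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(T)):
        # One-step TD error, masked at episode boundaries
        delta = rewards[t] + gamma * values[t + 1] * masks[t + 1] - values[t]
        gae = delta + gamma * tau * masks[t + 1] * gae
        returns[t] = gae + values[t]
    return returns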
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    ndevices = torch.cuda.device_count()
    # Setup loggers
    tbwriter = SummaryWriter(log_dir=args.log_dir)
    logging.basicConfig(filename=f"{args.log_dir}/train_log.txt", level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.getLogger().setLevel(logging.INFO)

    if "habitat" in args.env_name:
        devices = [int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
        # Devices need to be indexed between 0 and N-1
        devices = [dev for dev in range(len(devices))]
        envs = make_vec_envs_habitat(
            args.habitat_config_file, device, devices, seed=args.seed
        )
    else:
        train_log_dir = os.path.join(args.log_dir, "train_monitor")
        try:
            os.makedirs(train_log_dir)
        except OSError:
            pass
        envs = make_vec_envs_avd(
            args.env_name,
            args.seed,
            args.num_processes,
            train_log_dir,
            device,
            True,
            num_frame_stack=1,
            split="train",
            nRef=args.num_pose_refs,
        )

    args.feat_shape_sim = (512,)
    args.obs_shape = envs.observation_space.spaces["im"].shape
    args.odometer_shape = (4,)  # (delta_y, delta_x, delta_head, delta_elev)

    # =================== Load clusters =================
    clusters_h5 = h5py.File(args.clusters_path, "r")
    cluster_centroids = torch.Tensor(np.array(clusters_h5["cluster_centroids"])).to(device)
    cluster_centroids_t = cluster_centroids.t()
    args.nclusters = cluster_centroids.shape[0]
    clusters2images = {}
    for i in range(args.nclusters):
        cluster_images = np.array(clusters_h5[f"cluster_{i}/images"])  # (K, C, H, W)
        cluster_images = rearrange(cluster_images, "k c h w -> k h w c")
        cluster_images = (cluster_images * 255.0).astype(np.uint8)
        clusters2images[i] = cluster_images  # (K, H, W, C)
    clusters_h5.close()

    # =================== Create models ====================
    decoder = FeatureReconstructionModule(
        args.nclusters,
        args.nclusters,
        nlayers=args.n_transformer_layers,
    )
    feature_network = FeatureNetwork()
    pose_encoder = PoseEncoder()
    encoder = RGBEncoder() if args.encoder_type == "rgb" else MapRGBEncoder()
    action_config = (
        {"nactions": envs.action_space.n, "embedding_size": args.action_embedding_size}
        if args.use_action_embedding
        else None
    )
    collision_config = (
        {"collision_dim": 2, "embedding_size": args.collision_embedding_size}
        if args.use_collision_embedding
        else None
    )
    actor_critic = Policy(
        envs.action_space,
        base_kwargs={
            "feat_dim": args.feat_shape_sim[0],
            "recurrent": True,
            "hidden_size": args.feat_shape_sim[0],
            "action_config": action_config,
            "collision_config": collision_config,
        },
    )

    # =================== Load models ====================
    decoder_state, pose_encoder_state = torch.load(args.load_path_rec)[:2]
    # Remove DataParallel related strings
    new_decoder_state, new_pose_encoder_state = {}, {}
    for k, v in decoder_state.items():
        new_decoder_state[k.replace("module.", "")] = v
    for k, v in pose_encoder_state.items():
        new_pose_encoder_state[k.replace("module.", "")] = v
    decoder.load_state_dict(new_decoder_state)
    pose_encoder.load_state_dict(new_pose_encoder_state)
    decoder = nn.DataParallel(decoder, dim=1)
    pose_encoder = nn.DataParallel(pose_encoder, dim=0)

    save_path = os.path.join(args.save_dir, "checkpoints")
    checkpoint_path = os.path.join(save_path, "ckpt.latest.pth")
    if os.path.isfile(checkpoint_path):
        logging.info("Resuming from old model!")
        loaded_states = torch.load(checkpoint_path)
        encoder_state, actor_critic_state, j_start = loaded_states
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
    elif args.pretrained_il_model != "":
        logging.info("Initializing with pre-trained model!")
        encoder_state, actor_critic_state, _ = torch.load(args.pretrained_il_model)
        actor_critic.load_state_dict(actor_critic_state)
        encoder.load_state_dict(encoder_state)
        j_start = -1
    else:
        j_start = -1

    encoder.to(device)
    actor_critic.to(device)
    decoder.to(device)
    feature_network.to(device)
    pose_encoder.to(device)
    encoder.train()
    actor_critic.train()
    # decoder, feature_network, pose_encoder are frozen during policy training
    decoder.eval()
    feature_network.eval()
    pose_encoder.eval()

    # =================== Define RL training algorithm ====================
    rl_algo_config = {}
    rl_algo_config["lr"] = args.lr
    rl_algo_config["eps"] = args.eps
    rl_algo_config["encoder_type"] = args.encoder_type
    rl_algo_config["max_grad_norm"] = args.max_grad_norm
    rl_algo_config["clip_param"] = args.clip_param
    rl_algo_config["ppo_epoch"] = args.ppo_epoch
    rl_algo_config["entropy_coef"] = args.entropy_coef
    rl_algo_config["num_mini_batch"] = args.num_mini_batch
    rl_algo_config["value_loss_coef"] = args.value_loss_coef
    rl_algo_config["use_clipped_value_loss"] = False
    rl_algo_config["nactions"] = envs.action_space.n
    rl_algo_config["encoder"] = encoder
    rl_algo_config["actor_critic"] = actor_critic
    rl_algo_config["use_action_embedding"] = args.use_action_embedding
    rl_algo_config["use_collision_embedding"] = args.use_collision_embedding
    rl_agent = PPO(rl_algo_config)

    # =================== Define rollouts ====================
    rollouts_recon = RolloutStorageReconstruction(
        args.num_steps,
        args.num_processes,
        (args.nclusters,),
        args.odometer_shape,
        args.num_pose_refs,
    )
    rollouts_policy = RolloutStoragePPO(
        args.num_rl_steps,
        args.num_processes,
        args.obs_shape,
        envs.action_space,
        args.feat_shape_sim[0],
        encoder_type=args.encoder_type,
    )
    rollouts_recon.to(device)
    rollouts_policy.to(device)

    def get_obs(obs):
        obs_im = process_image(obs["im"])
        if args.encoder_type == "rgb+map":
            obs_lm = process_image(obs["coarse_occupancy"])
            obs_sm = process_image(obs["fine_occupancy"])
        else:
            obs_lm = None
            obs_sm = None
        return obs_im, obs_sm, obs_lm

    start = time.time()
    NPROC = args.num_processes
    NREF = args.num_pose_refs

    for j in range(j_start + 1, num_updates):
        # =================== Start a new episode ====================
        obs = envs.reset()
        # Processing environment inputs
        obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
        obs_odometer = process_odometer(obs["delta"])  # (num_processes, 4)
        # Convert mm to m for AVD
        if "avd" in args.env_name:
            obs_odometer[:, :2] /= 1000.0
        obs_collns = obs["collisions"].long()  # (num_processes, 1)

        # ============== Target poses and corresponding images ================
        # NOTE - these are constant throughout the episode.
        # (num_processes * num_pose_refs, 3) --- (y, x, t)
        tgt_poses = process_odometer(flatten_two(obs["pose_regress"]))[:, :3]
        tgt_poses = unflatten_two(tgt_poses, NPROC, NREF)  # (N, nRef, 3)
        tgt_masks = obs["valid_masks"].unsqueeze(2)  # (N, nRef, 1)
        # Convert mm to m for AVD
        if "avd" in args.env_name:
            tgt_poses[:, :, :2] /= 1000.0
        tgt_ims = process_image(flatten_two(obs["pose_refs"]))  # (N*nRef, C, H, W)

        # Initialize the memory of rollouts for reconstruction
        rollouts_recon.reset()
        with torch.no_grad():
            obs_feat = feature_network(obs_im)  # (N, 2048)
            tgt_feat = feature_network(tgt_ims)  # (N*nRef, 2048)
            # Compute similarity scores with all other clusters
            obs_feat = torch.matmul(obs_feat, cluster_centroids_t)  # (N, nclusters)
            tgt_feat = torch.matmul(tgt_feat, cluster_centroids_t)  # (N*nRef, nclusters)
        tgt_feat = unflatten_two(tgt_feat, NPROC, NREF)  # (N, nRef, nclusters)
        rollouts_recon.obs_feats[0].copy_(obs_feat)
        rollouts_recon.obs_odometer[0].copy_(obs_odometer)
        rollouts_recon.tgt_poses.copy_(tgt_poses)
        rollouts_recon.tgt_feats.copy_(tgt_feat)
        rollouts_recon.tgt_masks.copy_(tgt_masks)

        # Initialize the memory of rollouts for policy
        rollouts_policy.reset()
        rollouts_policy.obs_im[0].copy_(obs_im)
        if args.encoder_type == "rgb+map":
            rollouts_policy.obs_sm[0].copy_(obs_sm)
            rollouts_policy.obs_lm[0].copy_(obs_lm)
        rollouts_policy.collisions[0].copy_(obs_collns)

        # Episode statistics
        episode_expl_rewards = np.zeros((NPROC, 1))
        episode_collisions = np.zeros((NPROC, 1))
        episode_rec_rewards = np.zeros((NPROC, 1))
        episode_collisions += obs_collns.cpu().numpy()
        # Metrics
        osr_tracker = [0.0 for _ in range(NPROC)]
        objects_tracker = [0.0 for _ in range(NPROC)]
        area_tracker = [0.0 for _ in range(NPROC)]
        novelty_tracker = [0.0 for _ in range(NPROC)]
        smooth_coverage_tracker = [0.0 for _ in range(NPROC)]
        per_proc_area = [0.0 for _ in range(NPROC)]
        # Other states
        prev_action = torch.zeros(NPROC, 1).long().to(device)
        prev_collision = rollouts_policy.collisions[0]
        rec_reward_interval = args.rec_reward_interval
        prev_rec_rewards = torch.zeros(NPROC, 1)  # (N, 1)
        prev_rec_rewards = prev_rec_rewards.to(device)
        rec_rewards_at_t0 = None

        # ================= Update over a full batch of episodes =================
        # num_steps must be total number of steps in each episode
        for step in range(args.num_steps):
            pstep = rollouts_policy.step
            with torch.no_grad():
                encoder_inputs = [rollouts_policy.obs_im[pstep]]
                if args.encoder_type == "rgb+map":
                    encoder_inputs.append(rollouts_policy.obs_sm[pstep])
                    encoder_inputs.append(rollouts_policy.obs_lm[pstep])
                obs_feats = encoder(*encoder_inputs)
                policy_inputs = {"features": obs_feats}
                if args.use_action_embedding:
                    policy_inputs["actions"] = prev_action.long()
                if args.use_collision_embedding:
                    policy_inputs["collisions"] = prev_collision.long()
                policy_outputs = actor_critic.act(
                    policy_inputs,
                    rollouts_policy.recurrent_hidden_states[pstep],
                    rollouts_policy.masks[pstep],
                )
                (
                    value,
                    action,
                    action_log_probs,
                    recurrent_hidden_states,
                ) = policy_outputs

            # Act, get reward and next obs
            obs, reward, done, infos = envs.step(action)

            # Processing environment inputs
            obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
            obs_odometer = process_odometer(obs["delta"])  # (num_processes, 4)
            if "avd" in args.env_name:
                obs_odometer[:, :2] /= 1000.0
            obs_collns = obs["collisions"]  # (N, 1)
            with torch.no_grad():
                obs_feat = feature_network(obs_im)
                # Compute similarity scores with all other clusters
                obs_feat = torch.matmul(obs_feat, cluster_centroids_t)  # (N, nclusters)

            # Always set masks to 1 (since this loop happens within one episode)
            masks = torch.FloatTensor([[1.0] for _ in range(NPROC)]).to(device)
            # Accumulate odometer readings to give relative pose from the starting point
            obs_odometer = rollouts_recon.obs_odometer[step] * masks + obs_odometer

            # Update rollouts_recon
            rollouts_recon.insert(obs_feat, obs_odometer)

            # Compute the exploration rewards
            reward_exploration = torch.zeros(NPROC, 1)  # (N, 1)
            for proc in range(NPROC):
                seen_area = float(infos[proc]["seen_area"])
                objects_visited = infos[proc].get("num_objects_visited", 0.0)
                oracle_success = float(infos[proc]["oracle_pose_success"])
                novelty_reward = infos[proc].get("count_based_reward", 0.0)
                smooth_coverage_reward = infos[proc].get("coverage_novelty_reward", 0.0)
                area_reward = seen_area - area_tracker[proc]
                objects_reward = objects_visited - objects_tracker[proc]
                landmarks_reward = oracle_success - osr_tracker[proc]
                collision_reward = -obs_collns[proc, 0].item()
                area_tracker[proc] = seen_area
                objects_tracker[proc] = objects_visited
                osr_tracker[proc] = oracle_success
                per_proc_area[proc] = seen_area
                novelty_tracker[proc] += novelty_reward
                smooth_coverage_tracker[proc] += smooth_coverage_reward

            # Compute reconstruction rewards
            if (step + 1) % rec_reward_interval == 0 or step == 0:
                rec_rewards = compute_reconstruction_rewards(
                    rollouts_recon.obs_feats[: (step + 1)],
                    rollouts_recon.obs_odometer[: (step + 1), :, :3],
                    rollouts_recon.tgt_feats,
                    rollouts_recon.tgt_poses,
                    cluster_centroids_t,
                    decoder,
                    pose_encoder,
                ).detach()  # (N, nRef)
                rec_rewards = rec_rewards * tgt_masks.squeeze(2)  # (N, nRef)
                rec_rewards = rec_rewards.sum(dim=1).unsqueeze(1)  # / (tgt_masks.sum(dim=1) + 1e-8)
                final_rec_rewards = rec_rewards - prev_rec_rewards
                # Ignore the exploration reward at T=0 since it will be a huge spike
                if (("avd" in args.env_name) and (step != 0)) or (
                    ("habitat" in args.env_name) and (step > 20)
                ):
                    reward_exploration += final_rec_rewards.cpu() * args.rec_reward_scale
                    episode_rec_rewards += final_rec_rewards.cpu().numpy()
                prev_rec_rewards = rec_rewards

            overall_reward = (
                reward * (1 - args.reward_scale)
                + reward_exploration * args.reward_scale
            )

            # Update statistics
            episode_expl_rewards += reward_exploration.numpy() * args.reward_scale

            # Update rollouts_policy
            rollouts_policy.insert(
                obs_im,
                obs_sm,
                obs_lm,
                recurrent_hidden_states,
                action,
                action_log_probs,
                value,
                overall_reward,
                masks,
                obs_collns,
            )

            # Update prev values
            prev_collision = obs_collns
            prev_action = action
            episode_collisions += obs_collns.cpu().numpy()

            # Update RL policy
            if (step + 1) % args.num_rl_steps == 0:
                # Update value function for last step
                with torch.no_grad():
                    encoder_inputs = [rollouts_policy.obs_im[-1]]
                    if args.encoder_type == "rgb+map":
                        encoder_inputs.append(rollouts_policy.obs_sm[-1])
                        encoder_inputs.append(rollouts_policy.obs_lm[-1])
                    obs_feats = encoder(*encoder_inputs)
                    policy_inputs = {"features": obs_feats}
                    if args.use_action_embedding:
                        policy_inputs["actions"] = prev_action.long()
                    if args.use_collision_embedding:
                        policy_inputs["collisions"] = prev_collision.long()
                    next_value = actor_critic.get_value(
                        policy_inputs,
                        rollouts_policy.recurrent_hidden_states[-1],
                        rollouts_policy.masks[-1],
                    ).detach()
                # Compute returns
                rollouts_policy.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
                # Update model
                rl_losses = rl_agent.update(rollouts_policy)
                # Refresh rollouts_policy
                rollouts_policy.after_update()

        # =================== Save model ====================
        if (j + 1) % args.save_interval == 0 and args.save_dir != "":
            save_path = f"{args.save_dir}/checkpoints"
            try:
                os.makedirs(save_path)
            except OSError:
                pass
            encoder_state = encoder.state_dict()
            actor_critic_state = actor_critic.state_dict()
            torch.save(
                [encoder_state, actor_critic_state, j],
                f"{save_path}/ckpt.latest.pth",
            )
            if args.save_unique:
                torch.save(
                    [encoder_state, actor_critic_state, j],
                    f"{save_path}/ckpt.{(j+1):07d}.pth",
                )

        # =================== Logging data ====================
        total_num_steps = (j + 1 - j_start) * NPROC * args.num_steps
        if j % args.log_interval == 0:
            end = time.time()
            fps = int(total_num_steps / (end - start))
            logging.info(f"===> Updates {j}, #steps {total_num_steps}, FPS {fps}")
            train_metrics = rl_losses
            train_metrics["exploration_rewards"] = (
                np.mean(episode_expl_rewards) * rec_reward_interval / args.num_steps
            )
            train_metrics["rec_rewards"] = (
                np.mean(episode_rec_rewards) * rec_reward_interval / args.num_steps
            )
            train_metrics["area_covered"] = np.mean(per_proc_area)
            train_metrics["objects_covered"] = np.mean(objects_tracker)
            train_metrics["landmarks_covered"] = np.mean(osr_tracker)
            train_metrics["collisions"] = np.mean(episode_collisions)
            train_metrics["novelty_rewards"] = np.mean(novelty_tracker)
            train_metrics["smooth_coverage_rewards"] = np.mean(smooth_coverage_tracker)
            for k, v in train_metrics.items():
                logging.info(f"{k}: {v:.3f}")
                tbwriter.add_scalar(f"train_metrics/{k}", v, j)

        # =================== Evaluate models ====================
        if args.eval_interval is not None and (j + 1) % args.eval_interval == 0:
            if "habitat" in args.env_name:
                devices = [
                    int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
                ]
                # Devices need to be indexed between 0 and N-1
                devices = [dev for dev in range(len(devices))]
                eval_envs = make_vec_envs_habitat(
                    args.eval_habitat_config_file, device, devices
                )
            else:
                eval_envs = make_vec_envs_avd(
                    args.env_name,
                    args.seed + 12,
                    12,
                    eval_log_dir,
                    device,
                    True,
                    split="val",
                    nRef=NREF,
                    set_return_topdown_map=True,
                )
            num_eval_episodes = 16 if "habitat" in args.env_name else 30
            eval_config = {}
            eval_config["num_steps"] = args.num_steps
            eval_config["feat_shape_sim"] = args.feat_shape_sim
            eval_config["num_processes"] = 1 if "habitat" in args.env_name else 12
            eval_config["odometer_shape"] = args.odometer_shape
            eval_config["num_eval_episodes"] = num_eval_episodes
            eval_config["num_pose_refs"] = NREF
            eval_config["env_name"] = args.env_name
            eval_config["actor_type"] = "learned_policy"
            eval_config["encoder_type"] = args.encoder_type
            eval_config["use_action_embedding"] = args.use_action_embedding
            eval_config["use_collision_embedding"] = args.use_collision_embedding
            eval_config["cluster_centroids"] = cluster_centroids
            eval_config["clusters2images"] = clusters2images
            eval_config["rec_loss_fn"] = rec_loss_fn_classify
            eval_config["vis_save_dir"] = f"{args.save_dir}/policy_vis/update_{(j+1):05d}"
            models = {}
            models["decoder"] = decoder
            models["pose_encoder"] = pose_encoder
            models["feature_network"] = feature_network
            models["encoder"] = encoder
            models["actor_critic"] = actor_critic
            val_metrics, _ = evaluate_reconstruction(models, eval_envs, eval_config, device)
            decoder.eval()
            pose_encoder.eval()
            feature_network.eval()
            actor_critic.train()
            encoder.train()
            for k, v in val_metrics.items():
                tbwriter.add_scalar(f"val_metrics/{k}", v, j)

    tbwriter.close()
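# ============ Reconstruction reward: illustrative sketch ============
# Two pieces of the loop above, shown in isolation under assumed shapes.
# First, raw features are projected onto K concept-cluster similarity scores
# by a dot product with the centroid matrix. Second, the agent is rewarded for
# the *improvement* in reconstruction score since the last computation, which
# is why prev_rec_rewards is subtracted before scaling. The names below are
# illustrative, not the repository's API.
import torch


def cluster_scores(feats: torch.Tensor, centroids_t: torch.Tensor) -> torch.Tensor:
    # feats: (N, D), centroids_t: (D, K) -> similarity scores: (N, K)
    return torch.matmul(feats, centroids_t)


def delta_reward(curr_score: torch.Tensor, prev_score: torch.Tensor, scale: float):
    # Difference-based shaping: summing these per-step rewards over an episode
    # telescopes to the final reconstruction score.
    return scale * (curr_score - prev_score)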
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    ndevices = torch.cuda.device_count()
    args.map_shape = (1, args.map_size, args.map_size)
    # Setup loggers
    logging.basicConfig(filename=f"{args.log_dir}/eval_log.txt", level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.getLogger().setLevel(logging.INFO)

    args.feat_shape_sim = (512,)
    args.feat_shape_pose = (512 * 9,)
    args.odometer_shape = (4,)  # (delta_y, delta_x, delta_head, delta_elev)
    args.match_thresh = 0.95
    args.requires_policy = args.actor_type not in [
        "random",
        "oracle",
        "forward",
        "forward-plus",
        "frontier",
    ]

    if "habitat" in args.env_name:
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            devices = [int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
            # Devices need to be indexed between 0 and N-1
            devices = [dev for dev in range(len(devices))]
        else:
            devices = None
        eval_envs = make_vec_envs_habitat(
            args.habitat_config_file,
            device,
            devices,
            enable_odometry_noise=args.enable_odometry_noise,
            odometer_noise_scaling=args.odometer_noise_scaling,
            measure_noise_free_area=args.measure_noise_free_area,
        )
        if args.actor_type == "frontier":
            large_map_range = 100.0
            H = eval_envs.observation_space.spaces["highres_coarse_occupancy"].shape[1]
            args.occ_map_scale = 0.1 * (2 * large_map_range + 1) / H
    else:
        eval_envs = make_vec_envs_avd(
            args.env_name,
            123 + args.num_processes,
            args.num_processes,
            eval_log_dir,
            device,
            True,
            split=args.eval_split,
            nRef=args.num_pose_refs,
            set_return_topdown_map=True,
        )
        if args.actor_type == "frontier":
            large_map_range = 100.0
            H = eval_envs.observation_space.spaces["highres_coarse_occupancy"].shape[0]
            args.occ_map_scale = 50.0 * (2 * large_map_range + 1) / H

    args.obs_shape = eval_envs.observation_space.spaces["im"].shape
    args.angles = torch.Tensor(np.radians(np.linspace(180, -150, 12))).to(device)
    args.bin_size = math.radians(31)

    # =================== Create models ====================
    rnet = RetrievalNetwork()
    posenet = PairwisePosePredictor(
        use_classification=args.use_classification, num_classes=args.num_classes
    )
    pose_head = ViewLocalizer(args.map_scale)
    if args.requires_policy:
        encoder = RGBEncoder() if args.encoder_type == "rgb" else MapRGBEncoder()
        action_config = (
            {
                "nactions": eval_envs.action_space.n,
                "embedding_size": args.action_embedding_size,
            }
            if args.use_action_embedding
            else None
        )
        collision_config = (
            {"collision_dim": 2, "embedding_size": args.collision_embedding_size}
            if args.use_collision_embedding
            else None
        )
        actor_critic = Policy(
            eval_envs.action_space,
            base_kwargs={
                "feat_dim": args.feat_shape_sim[0],
                "recurrent": True,
                "hidden_size": args.feat_shape_sim[0],
                "action_config": action_config,
                "collision_config": collision_config,
            },
        )

    # =================== Load models ====================
    rnet_state = torch.load(args.pretrained_rnet)["state_dict"]
    rnet.load_state_dict(rnet_state)
    posenet_state = torch.load(args.pretrained_posenet)["state_dict"]
    posenet.load_state_dict(posenet_state)
    rnet.to(device)
    posenet.to(device)
    pose_head.to(device)
    rnet.eval()
    posenet.eval()
    pose_head.eval()

    if args.requires_policy:
        encoder_state, actor_critic_state = torch.load(args.load_path)[:2]
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
        actor_critic.to(device)
        encoder.to(device)
        actor_critic.eval()
        encoder.eval()

    if args.use_multi_gpu:
        rnet.compare = nn.DataParallel(rnet.compare)
        rnet.feat_extract = nn.DataParallel(rnet.feat_extract)
        posenet.compare = nn.DataParallel(posenet.compare)
        posenet.feat_extract = nn.DataParallel(posenet.feat_extract)
        posenet.predict_depth = nn.DataParallel(posenet.predict_depth)
        posenet.predict_baseline = nn.DataParallel(posenet.predict_baseline)
        posenet.predict_baseline_sign = nn.DataParallel(posenet.predict_baseline_sign)

    # =================== Define pose criterion ====================
    args.pose_loss_fn = get_pose_criterion()
    lab_shape = get_pose_label_shape()
    gaussian_kernel = get_gaussian_kernel(
        kernel_size=args.vote_kernel_size, sigma=0.5, channels=1
    )

    eval_config = {}
    eval_config["num_steps"] = args.num_steps
    eval_config["num_processes"] = args.num_processes
    eval_config["obs_shape"] = args.obs_shape
    eval_config["feat_shape_sim"] = args.feat_shape_sim
    eval_config["feat_shape_pose"] = args.feat_shape_pose
    eval_config["odometer_shape"] = args.odometer_shape
    eval_config["lab_shape"] = lab_shape
    eval_config["map_shape"] = args.map_shape
    eval_config["map_scale"] = args.map_scale
    eval_config["angles"] = args.angles
    eval_config["bin_size"] = args.bin_size
    eval_config["gaussian_kernel"] = gaussian_kernel
    eval_config["match_thresh"] = args.match_thresh
    eval_config["pose_loss_fn"] = args.pose_loss_fn
    eval_config["num_eval_episodes"] = args.eval_episodes
    eval_config["num_pose_refs"] = args.num_pose_refs
    eval_config["median_filter_size"] = 3
    eval_config["vote_kernel_size"] = args.vote_kernel_size
    eval_config["env_name"] = args.env_name
    eval_config["actor_type"] = args.actor_type
    eval_config["pose_predictor_type"] = args.pose_predictor_type
    eval_config["encoder_type"] = args.encoder_type
    eval_config["ransac_n"] = args.ransac_n
    eval_config["ransac_niter"] = args.ransac_niter
    eval_config["ransac_batch"] = args.ransac_batch
    eval_config["use_action_embedding"] = args.use_action_embedding
    eval_config["use_collision_embedding"] = args.use_collision_embedding
    eval_config["vis_save_dir"] = os.path.join(args.log_dir, "visualizations")
    eval_config["final_topdown_save_path"] = os.path.join(args.log_dir, "top_down_maps.h5")
    eval_config["forward_action_id"] = 2 if "avd" in args.env_name else 0
    eval_config["turn_action_id"] = 0 if "avd" in args.env_name else 1
    eval_config["input_highres"] = args.input_highres
    if args.actor_type == "frontier":
        eval_config["occ_map_scale"] = args.occ_map_scale
        eval_config["frontier_dilate_occ"] = args.frontier_dilate_occ
        eval_config["max_time_per_target"] = args.max_time_per_target

    models = {}
    models["rnet"] = rnet
    models["posenet"] = posenet
    models["pose_head"] = pose_head
    if args.requires_policy:
        models["actor_critic"] = actor_critic
        models["encoder"] = encoder

    metrics, per_episode_metrics = evaluate_pose(
        models,
        eval_envs,
        eval_config,
        device,
        multi_step=True,
        interval_steps=args.interval_steps,
        visualize_policy=args.visualize_policy,
        visualize_size=args.visualize_size,
        visualize_batches=args.visualize_batches,
        visualize_n_per_batch=args.visualize_n_per_batch,
    )

    json.dump(
        per_episode_metrics,
        open(os.path.join(args.log_dir, "statistics.json"), "w"),
    )
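# ============ Gaussian vote smoothing: illustrative sketch ============
# get_gaussian_kernel(kernel_size, sigma, channels) above presumably builds a
# fixed depthwise convolution that blurs the pose vote maps before the argmax.
# A generic construction under that assumption (not the repository's
# implementation):
import torch
import torch.nn as nn


def make_gaussian_conv(kernel_size: int, sigma: float, channels: int) -> nn.Conv2d:
    # 1D Gaussian profile centered on the kernel
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2.0
    g1d = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    # Outer product gives the separable 2D kernel; normalize so that the
    # smoothing redistributes votes without inflating their total mass.
    g2d = torch.outer(g1d, g1d)
    g2d = g2d / g2d.sum()
    conv = nn.Conv2d(
        channels,
        channels,
        kernel_size,
        groups=channels,  # depthwise: each channel is smoothed independently
        padding=kernel_size // 2,
        bias=False,
    )
    conv.weight.data.copy_(g2d.expand(channels, 1, kernel_size, kernel_size))
    conv.weight.requires_grad_(False)
    return conv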
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    ndevices = torch.cuda.device_count()
    # Setup loggers
    logging.basicConfig(filename=f"{args.log_dir}/eval_log.txt", level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.getLogger().setLevel(logging.INFO)

    args.feat_shape_sim = (512,)
    args.odometer_shape = (4,)  # (delta_y, delta_x, delta_head, delta_elev)
    args.requires_policy = args.actor_type not in [
        "random",
        "oracle",
        "forward",
        "forward-plus",
        "frontier",
    ]

    if "habitat" in args.env_name:
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            devices = [int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
            # Devices need to be indexed between 0 and N-1
            devices = [dev for dev in range(len(devices))]
        else:
            devices = None
        eval_envs = make_vec_envs_habitat(
            args.habitat_config_file, device, devices, seed=args.seed
        )
        if args.actor_type == "frontier":
            large_map_range = 100.0
            H = eval_envs.observation_space.spaces["highres_coarse_occupancy"].shape[1]
            args.occ_map_scale = 0.1 * (2 * large_map_range + 1) / H
    else:
        eval_envs = make_vec_envs_avd(
            args.env_name,
            args.seed + args.num_processes,
            args.num_processes,
            eval_log_dir,
            device,
            True,
            split=args.eval_split,
            nRef=args.num_pose_refs,
            set_return_topdown_map=True,
        )
        if args.actor_type == "frontier":
            large_map_range = 100.0
            H = eval_envs.observation_space.spaces["highres_coarse_occupancy"].shape[0]
            args.occ_map_scale = 50.0 * (2 * large_map_range + 1) / H

    args.obs_shape = eval_envs.observation_space.spaces["im"].shape

    # =================== Load clusters =================
    clusters_h5 = h5py.File(args.clusters_path, "r")
    cluster_centroids = torch.Tensor(np.array(clusters_h5["cluster_centroids"])).to(device)
    args.nclusters = cluster_centroids.shape[0]
    clusters2images = {}
    for i in range(args.nclusters):
        cluster_images = np.array(clusters_h5[f"cluster_{i}/images"])  # (K, C, H, W)
        cluster_images = np.ascontiguousarray(cluster_images.transpose(0, 2, 3, 1))
        cluster_images = (cluster_images * 255.0).astype(np.uint8)
        clusters2images[i] = cluster_images  # (K, H, W, C)
    clusters_h5.close()

    # =================== Create models ====================
    decoder = FeatureReconstructionModule(
        args.nclusters,
        args.nclusters,
        nlayers=args.n_transformer_layers,
    )
    feature_network = FeatureNetwork()
    feature_network = nn.DataParallel(feature_network, dim=0)
    pose_encoder = PoseEncoder()
    if args.use_multi_gpu:
        decoder = nn.DataParallel(decoder, dim=1)
        pose_encoder = nn.DataParallel(pose_encoder, dim=0)
    if args.requires_policy:
        encoder = RGBEncoder() if args.encoder_type == "rgb" else MapRGBEncoder()
        action_config = (
            {
                "nactions": eval_envs.action_space.n,
                "embedding_size": args.action_embedding_size,
            }
            if args.use_action_embedding
            else None
        )
        collision_config = (
            {"collision_dim": 2, "embedding_size": args.collision_embedding_size}
            if args.use_collision_embedding
            else None
        )
        actor_critic = Policy(
            eval_envs.action_space,
            base_kwargs={
                "feat_dim": args.feat_shape_sim[0],
                "recurrent": True,
                "hidden_size": args.feat_shape_sim[0],
                "action_config": action_config,
                "collision_config": collision_config,
            },
        )

    # =================== Load models ====================
    decoder_state, pose_encoder_state = torch.load(args.load_path_rec)[:2]
    decoder.load_state_dict(decoder_state)
    pose_encoder.load_state_dict(pose_encoder_state)
    decoder.to(device)
    feature_network.to(device)
    decoder.eval()
    feature_network.eval()
    pose_encoder.eval()
    pose_encoder.to(device)

    if args.requires_policy:
        encoder_state, actor_critic_state = torch.load(args.load_path)[:2]
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
        actor_critic.to(device)
        encoder.to(device)
        actor_critic.eval()
        encoder.eval()

    eval_config = {}
    eval_config["num_steps"] = args.num_steps
    eval_config["num_processes"] = args.num_processes
    eval_config["feat_shape_sim"] = args.feat_shape_sim
    eval_config["odometer_shape"] = args.odometer_shape
    eval_config["num_eval_episodes"] = args.eval_episodes
    eval_config["num_pose_refs"] = args.num_pose_refs
    eval_config["env_name"] = args.env_name
    eval_config["actor_type"] = args.actor_type
    eval_config["encoder_type"] = args.encoder_type
    eval_config["use_action_embedding"] = args.use_action_embedding
    eval_config["use_collision_embedding"] = args.use_collision_embedding
    eval_config["cluster_centroids"] = cluster_centroids
    eval_config["clusters2images"] = clusters2images
    eval_config["rec_loss_fn"] = rec_loss_fn_classify
    eval_config["vis_save_dir"] = os.path.join(args.log_dir, "visualizations")
    eval_config["forward_action_id"] = 2 if "avd" in args.env_name else 0
    eval_config["turn_action_id"] = 0 if "avd" in args.env_name else 1
    if args.actor_type == "frontier":
        eval_config["occ_map_scale"] = args.occ_map_scale
        eval_config["frontier_dilate_occ"] = args.frontier_dilate_occ
        eval_config["max_time_per_target"] = args.max_time_per_target

    models = {}
    models["decoder"] = decoder
    models["pose_encoder"] = pose_encoder
    models["feature_network"] = feature_network
    if args.requires_policy:
        models["actor_critic"] = actor_critic
        models["encoder"] = encoder

    metrics, per_episode_metrics = evaluate_reconstruction(
        models,
        eval_envs,
        eval_config,
        device,
        multi_step=True,
        interval_steps=args.interval_steps,
        visualize_policy=args.visualize_policy,
    )

    json.dump(
        per_episode_metrics,
        open(os.path.join(args.log_dir, "statistics.json"), "w"),
    )
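# ============ DataParallel checkpoints: illustrative note ============
# Checkpoints saved from nn.DataParallel-wrapped modules store parameters
# under a "module." prefix. The evaluation script above loads the saved states
# into (optionally) DataParallel-wrapped modules so the keys line up, while
# the reconstruction training script strips the prefix by hand before loading
# into bare modules. A small helper capturing that pattern (illustrative
# only, not part of the repository):
def strip_dataparallel_prefix(state_dict):
    return {k.replace("module.", "", 1): v for k, v in state_dict.items()}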
def main(): torch.set_num_threads(1) device = torch.device("cuda:0" if args.cuda else "cpu") ndevices = torch.cuda.device_count() # Setup loggers tbwriter = SummaryWriter(log_dir=args.log_dir) logging.basicConfig(filename=f"{args.log_dir}/train_log.txt", level=logging.DEBUG) logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging.getLogger().setLevel(logging.INFO) if "habitat" in args.env_name: devices = [ int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",") ] # Devices need to be indexed between 0 to N-1 devices = [dev for dev in range(len(devices))] envs = make_vec_envs_habitat(args.habitat_config_file, device, devices, seed=args.seed) else: train_log_dir = os.path.join(args.log_dir, "train_monitor") try: os.makedirs(train_log_dir) except OSError: pass envs = make_vec_envs_avd( args.env_name, args.seed, args.num_processes, train_log_dir, device, True, num_frame_stack=1, split="train", nRef=args.num_pose_refs, ) args.feat_shape_sim = (512, ) args.feat_shape_pose = (512 * 9, ) args.obs_shape = envs.observation_space.spaces["im"].shape args.agent_action_prob = args.agent_start_action_prob # =================== Create models ==================== encoder = RGBEncoder() if args.encoder_type == "rgb" else MapRGBEncoder() action_config = ({ "nactions": envs.action_space.n, "embedding_size": args.action_embedding_size } if args.use_action_embedding else None) collision_config = ({ "collision_dim": 2, "embedding_size": args.collision_embedding_size } if args.use_collision_embedding else None) actor_critic = Policy( envs.action_space, base_kwargs={ "feat_dim": args.feat_shape_sim[0], "recurrent": True, "hidden_size": args.feat_shape_sim[0], "action_config": action_config, "collision_config": collision_config, }, ) # =================== Load models ==================== save_path = os.path.join(args.save_dir, "checkpoints") checkpoint_path = os.path.join(save_path, "ckpt.latest.pth") if os.path.isfile(checkpoint_path): logging.info("Resuming from old model!") loaded_states = torch.load(checkpoint_path) encoder_state, actor_critic_state, j_start = loaded_states encoder.load_state_dict(encoder_state) actor_critic.load_state_dict(actor_critic_state) else: j_start = -1 actor_critic.to(device) encoder.to(device) actor_critic.train() encoder.train() # =================== Define IL training algorithm ==================== il_algo_config = {} il_algo_config["lr"] = args.lr il_algo_config["eps"] = args.eps il_algo_config["max_grad_norm"] = args.max_grad_norm il_algo_config["encoder_type"] = args.encoder_type il_algo_config["nactions"] = envs.action_space.n il_algo_config["encoder"] = encoder il_algo_config["actor_critic"] = actor_critic il_algo_config["use_action_embedding"] = args.use_action_embedding il_algo_config["use_collision_embedding"] = args.use_collision_embedding il_algo_config["use_inflection_weighting"] = args.use_inflection_weighting il_agent = Imitation(il_algo_config) # =================== Define rollouts ==================== rollouts_policy = RolloutStorageImitation( args.num_rl_steps, args.num_processes, args.obs_shape, envs.action_space, args.feat_shape_sim[0], encoder_type=args.encoder_type, ) rollouts_policy.to(device) def get_obs(obs): obs_im = process_image(obs["im"]) if args.encoder_type == "rgb+map": obs_lm = process_image(obs["coarse_occupancy"]) obs_sm = process_image(obs["fine_occupancy"]) else: obs_lm = None obs_sm = None return obs_im, obs_sm, obs_lm start = time.time() NPROC = args.num_processes for j in range(j_start + 1, num_updates): # 
    for j in range(j_start + 1, num_updates):
        # =================== Start a new episode ====================
        obs = envs.reset()
        # Processing environment inputs
        obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
        obs_collns = obs["collisions"].long()  # (num_processes, 1)
        # Initialize the memory of rollouts for policy
        rollouts_policy.reset()
        rollouts_policy.obs_im[0].copy_(obs_im)
        if args.encoder_type == "rgb+map":
            rollouts_policy.obs_sm[0].copy_(obs_sm)
            rollouts_policy.obs_lm[0].copy_(obs_lm)
        rollouts_policy.collisions[0].copy_(obs_collns)
        # Episode statistics
        episode_expl_rewards = np.zeros((NPROC, 1))
        episode_collisions = np.zeros((NPROC, 1))
        episode_collisions += obs_collns.cpu().numpy()
        # Metrics
        per_proc_area = [0.0 for _ in range(NPROC)]
        # Other states
        prev_action = torch.zeros(NPROC, 1).long().to(device)
        prev_collision = rollouts_policy.collisions[0]
        agent_acting_duration = 0
        agent_acting_status = False

        # =============== Update over a full batch of episodes ================
        # num_steps must be total number of steps in each episode
        for step in range(args.num_steps):
            pstep = rollouts_policy.step
            with torch.no_grad():
                encoder_inputs = [rollouts_policy.obs_im[pstep]]
                if args.encoder_type == "rgb+map":
                    encoder_inputs.append(rollouts_policy.obs_sm[pstep])
                    encoder_inputs.append(rollouts_policy.obs_lm[pstep])
                obs_feats = encoder(*encoder_inputs)
                policy_inputs = {"features": obs_feats}
                if args.use_action_embedding:
                    policy_inputs["actions"] = prev_action.long()
                if args.use_collision_embedding:
                    policy_inputs["collisions"] = prev_collision.long()
                policy_outputs = actor_critic.act(
                    policy_inputs,
                    rollouts_policy.recurrent_hidden_states[pstep],
                    rollouts_policy.masks[pstep],
                )
                (
                    value,
                    agent_action,
                    action_log_probs,
                    recurrent_hidden_states,
                ) = policy_outputs

            oracle_action = obs["oracle_action"].long()
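            # Scheduled-sampling note (added comment; the numbers are
            # illustrative): action_masks below selects, per process, whether
            # to execute the oracle's action (mask = 1) or the agent's own
            # action (mask = 0). For example, with NPROC = 2,
            # action_masks = [[1], [0]], oracle_action = [[2], [2]] and
            # agent_action = [[0], [1]], the mixed action works out to
            # [[2], [1]]: process 0 follows the oracle while process 1
            # executes its own policy.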
            # If the action mask is active, take the oracle action;
            # otherwise, take the agent's action
            if args.agent_action_duration == 1:
                # Per-process sampling: the agent acts with probability
                # agent_action_prob (torch.rand is used here instead of the
                # original torch.cuda.FloatTensor so this also runs on CPU)
                action_masks = (torch.rand(NPROC, 1, device=device) >=
                                args.agent_action_prob)
            else:
                # agent_action_duration HAS to be at least 2 to enter this
                # Agent continues acting
                if (agent_acting_status and agent_acting_duration > 0
                        and agent_acting_duration <= args.agent_action_duration):
                    action_masks = torch.zeros(NPROC, 1).to(device)
                    agent_acting_duration = (
                        agent_acting_duration + 1) % args.agent_action_duration
                    # Agent is done acting
                    if agent_acting_duration == 0:
                        agent_acting_status = False
                # Agent starts acting
                elif random.random() < args.agent_action_prob:
                    action_masks = torch.zeros(NPROC, 1).to(device)
                    agent_acting_status = True
                    agent_acting_duration += 1
                # Agent does not act
                else:
                    action_masks = torch.ones(NPROC, 1).to(device)
            action_masks = action_masks.long()
            action = oracle_action * action_masks + agent_action * (
                1 - action_masks)

            # Act, get reward and next obs
            obs, reward, done, infos = envs.step(action)

            # Processing observations
            obs_im, obs_sm, obs_lm = get_obs(obs)  # (num_processes, 3, 84, 84)
            obs_collns = obs["collisions"]  # (N, 1)

            # Always set masks to 1 (since this loop happens within one episode)
            masks = torch.FloatTensor([[1.0] for _ in range(NPROC)]).to(device)

            # Compute the exploration rewards
            reward_exploration = torch.zeros(NPROC, 1)  # (N, 1)
            for proc in range(NPROC):
                reward_exploration[proc] += (float(infos[proc]["seen_area"]) -
                                             per_proc_area[proc])
                per_proc_area[proc] = float(infos[proc]["seen_area"])

            overall_reward = (reward * (1 - args.reward_scale) +
                              reward_exploration * args.reward_scale)

            # Update statistics
            episode_expl_rewards += (reward_exploration.numpy() *
                                     args.reward_scale)

            # Update rollouts_policy
            rollouts_policy.insert(
                obs_im,
                obs_sm,
                obs_lm,
                recurrent_hidden_states,
                action,
                action_log_probs,
                value,
                overall_reward,
                masks,
                obs_collns,
                action_masks,
            )

            # Update prev values
            prev_collision = obs_collns
            prev_action = action
            episode_collisions += obs_collns.cpu().numpy()

            # Update IL policy
            if (step + 1) % args.num_rl_steps == 0:
                # Update model
                il_losses = il_agent.update(rollouts_policy)
                # Refresh rollouts
                rollouts_policy.after_update()

        # =================== Save model ====================
        if (j + 1) % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, "checkpoints")
            try:
                os.makedirs(save_path)
            except OSError:
                pass
            encoder_state = encoder.state_dict()
            actor_critic_state = actor_critic.state_dict()
            torch.save(
                [encoder_state, actor_critic_state, j],
                os.path.join(save_path, "ckpt.latest.pth"),
            )
            if args.save_unique:
                torch.save(
                    [encoder_state, actor_critic_state, j],
                    f"{save_path}/ckpt.{(j+1):07d}.pth",
                )

        # =================== Logging data ====================
        total_num_steps = (j + 1 - j_start) * NPROC * args.num_steps
        if j % args.log_interval == 0:
            end = time.time()
            fps = int(total_num_steps / (end - start))
            logging.info(
                f"===> Updates {j}, #steps {total_num_steps}, FPS {fps}")
            train_metrics = il_losses
            train_metrics["exploration_rewards"] = np.mean(
                episode_expl_rewards)
            train_metrics["area_covered"] = np.mean(per_proc_area)
            train_metrics["collisions"] = np.mean(episode_collisions)
            for k, v in train_metrics.items():
                logging.info(f"{k}: {v:.3f}")
                tbwriter.add_scalar(f"train_metrics/{k}", v, j)
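        # Usage note (added comment): the train_metrics/* scalars written
        # above can be browsed with TensorBoard, e.g.
        # `tensorboard --logdir <args.log_dir>`.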
        # =================== Evaluate models ====================
        if args.eval_interval is not None and (j + 1) % args.eval_interval == 0:
            if "habitat" in args.env_name:
                devices = [
                    int(dev)
                    for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
                ]
                # Devices need to be indexed from 0 to N-1
                devices = [dev for dev in range(len(devices))]
                eval_envs = make_vec_envs_habitat(
                    args.eval_habitat_config_file, device, devices)
            else:
                eval_envs = make_vec_envs_avd(
                    args.env_name,
                    args.seed + 12,
                    12,
                    eval_log_dir,
                    device,
                    True,
                    split="val",
                    nRef=args.num_pose_refs,
                    set_return_topdown_map=True,
                )

            num_eval_episodes = 16 if "habitat" in args.env_name else 30

            eval_config = {}
            eval_config["num_steps"] = args.num_steps
            eval_config["feat_shape_sim"] = args.feat_shape_sim
            eval_config["num_processes"] = (1 if "habitat" in args.env_name
                                            else 12)
            eval_config["num_pose_refs"] = args.num_pose_refs
            eval_config["num_eval_episodes"] = num_eval_episodes
            eval_config["env_name"] = args.env_name
            eval_config["actor_type"] = "learned_policy"
            eval_config["encoder_type"] = args.encoder_type
            eval_config["use_action_embedding"] = args.use_action_embedding
            eval_config["use_collision_embedding"] = args.use_collision_embedding
            eval_config["vis_save_dir"] = (
                f"{args.save_dir}/policy_vis/update_{(j+1):05d}")

            models = {}
            models["encoder"] = encoder
            models["actor_critic"] = actor_critic

            val_metrics, _ = evaluate_visitation(models,
                                                 eval_envs,
                                                 eval_config,
                                                 device,
                                                 visualize_policy=False)
            for k, v in val_metrics.items():
                tbwriter.add_scalar(f"val_metrics/{k}", v, j)

        # =========== Update agent action schedule ==========
        # Anneal the probability of sampling the agent's own actions (rather
        # than the oracle's) every agent_action_prob_schedule updates, capped
        # at agent_end_action_prob
        if (j + 1) % args.agent_action_prob_schedule == 0:
            args.agent_action_prob += args.agent_action_prob_factor
            args.agent_action_prob = min(args.agent_action_prob,
                                         args.agent_end_action_prob)
            logging.info(
                f"=======> Updated action sampling schedule to "
                f"{args.agent_action_prob:.3f}")

    tbwriter.close()
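# Assumed entry point: the flattened excerpt does not show one, but the script
# is structured around main(), so a standard guard is sketched here. The args
# object is presumed to be parsed at module level, as in the rest of the file.
if __name__ == "__main__":
    main()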