def load_model():
    depth_256_space = SpaceDict({
        'depth': spaces.Box(low=0., high=1., shape=(256, 256, 1)),
        'pointgoal_with_gps_compass': spaces.Box(
            low=np.finfo(np.float32).min,
            high=np.finfo(np.float32).max,
            shape=(2,),
            dtype=np.float32,
        ),
    })

    if GAUSSIAN:
        action_space = spaces.Box(
            np.array([float('-inf'), float('-inf')]),
            np.array([float('inf'), float('inf')]))
        action_distribution = 'gaussian'
        dim_actions = 2
    elif DISCRETE_4:
        action_space = spaces.Discrete(4)
        action_distribution = 'categorical'
        dim_actions = 4
    elif DISCRETE_6:
        action_space = spaces.Discrete(6)
        action_distribution = 'categorical'
        dim_actions = 6

    model = PointNavResNetPolicy(
        observation_space=depth_256_space,
        action_space=action_space,
        hidden_size=512,
        rnn_type='LSTM',
        num_recurrent_layers=2,
        backbone='resnet50',
        normalize_visual_inputs=False,
        action_distribution=action_distribution,
        dim_actions=dim_actions)
    model.to(torch.device("cpu"))

    # The weights are stored as JSON (nested lists keyed by parameter name), so
    # rebuild each tensor and strip the "actor_critic." prefix before loading.
    data_dict = OrderedDict()
    with open(WEIGHTS_PATH, 'r') as f:
        data_dict = json.load(f)
    model.load_state_dict({
        k[len("actor_critic."):]: torch.tensor(v)
        for k, v in data_dict.items()
        if k.startswith("actor_critic.")
    })

    return model
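# Hypothetical companion helper (not in the original code): a minimal sketch of how a
# JSON weights file in the format load_model() expects could be produced from an
# ordinary .pth checkpoint. It assumes the checkpoint stores an "actor_critic."-prefixed
# state_dict under "state_dict"; ckpt_path and out_path are placeholder names.
import json

import torch


def export_weights_to_json(ckpt_path, out_path):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # JSON cannot hold tensors, so convert each parameter to a nested Python list.
    serializable = {
        k: v.cpu().numpy().tolist()
        for k, v in ckpt["state_dict"].items()
        if k.startswith("actor_critic.")
    }
    with open(out_path, "w") as f:
        json.dump(serializable, f)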
class DDPPOTrainer(PPOTrainer):
    # DD-PPO cuts rollouts short to mitigate the straggler effect.
    # This, in theory, can cause some rollouts to be very short.
    # All rollouts contribute equally to the loss/model-update,
    # thus very short rollouts can be problematic. This threshold
    # limits how short a rollout can be as a fraction of the
    # max rollout length.
    SHORT_ROLLOUT_THRESHOLD: float = 0.25

    def __init__(self, config=None):
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            config = interrupted_state["config"]

        super().__init__(config)

    def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
        r"""Sets up actor critic and agent for DD-PPO.

        Args:
            ppo_cfg: config node with relevant params

        Returns:
            None
        """
        logger.add_filehandler(self.config.LOG_FILE)

        self.actor_critic = PointNavResNetPolicy(
            observation_space=self.envs.observation_spaces[0],
            action_space=self.envs.action_spaces[0],
            hidden_size=ppo_cfg.hidden_size,
            rnn_type=self.config.RL.DDPPO.rnn_type,
            num_recurrent_layers=self.config.RL.DDPPO.num_recurrent_layers,
            backbone=self.config.RL.DDPPO.backbone,
            goal_sensor_uuid=self.config.TASK_CONFIG.TASK.GOAL_SENSOR_UUID,
            normalize_visual_inputs="rgb" in self.envs.observation_spaces[0].spaces,
        )
        self.actor_critic.to(self.device)

        if (self.config.RL.DDPPO.pretrained_encoder
                or self.config.RL.DDPPO.pretrained):
            pretrained_state = torch.load(
                self.config.RL.DDPPO.pretrained_weights, map_location="cpu")

        if self.config.RL.DDPPO.pretrained:
            self.actor_critic.load_state_dict({
                k[len("actor_critic."):]: v
                for k, v in pretrained_state["state_dict"].items()
            })
        elif self.config.RL.DDPPO.pretrained_encoder:
            prefix = "actor_critic.net.visual_encoder."
            self.actor_critic.net.visual_encoder.load_state_dict({
                k[len(prefix):]: v
                for k, v in pretrained_state["state_dict"].items()
                if k.startswith(prefix)
            })

        if not self.config.RL.DDPPO.train_encoder:
            self._static_encoder = True
            for param in self.actor_critic.net.visual_encoder.parameters():
                param.requires_grad_(False)

        if self.config.RL.DDPPO.reset_critic:
            nn.init.orthogonal_(self.actor_critic.critic.fc.weight)
            nn.init.constant_(self.actor_critic.critic.fc.bias, 0)

        self.agent = DDPPO(
            actor_critic=self.actor_critic,
            clip_param=ppo_cfg.clip_param,
            ppo_epoch=ppo_cfg.ppo_epoch,
            num_mini_batch=ppo_cfg.num_mini_batch,
            value_loss_coef=ppo_cfg.value_loss_coef,
            entropy_coef=ppo_cfg.entropy_coef,
            lr=ppo_cfg.lr,
            eps=ppo_cfg.eps,
            max_grad_norm=ppo_cfg.max_grad_norm,
            use_normalized_advantage=ppo_cfg.use_normalized_advantage,
        )

    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (self.world_rank *
                                         self.config.NUM_PROCESSES)
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
                and self.world_rank == 0):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        obs_space = self.envs.observation_spaces[0]
        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            obs_space = SpaceDict({
                "visual_features": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            })
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        )
        rollouts.to(self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors. We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs, 1,
                                             device=self.device)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens. If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [value_loss, action_loss, count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[2].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {k: l for l, k in zip(losses, ["value", "policy"])},
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps / ((time.time() - t_start) + prev_time),
                        ))
                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {} {}".format(
                            len(window_episode_stats["count"]),
                            " ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                     for k, v in deltas.items()
                                     if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        count_checkpoints += 1

        self.envs.close()
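# Standalone sketch (not part of the trainer above) of the straggler-preemption test
# used in the rollout loop: a worker only cuts its rollout short once it has collected
# at least SHORT_ROLLOUT_THRESHOLD of its steps and more than sync_frac of all workers
# have already reported that they are done. The numeric values below are illustrative.
def should_preempt(step, num_steps, num_done, world_size,
                   short_rollout_threshold=0.25, sync_frac=0.6):
    return (step >= num_steps * short_rollout_threshold
            and num_done > sync_frac * world_size)


# With 128-step rollouts across 64 workers: a worker at step 40 preempts itself once
# more than 38.4 workers (0.6 * 64) have finished, but never before step 32 (128 * 0.25).
assert should_preempt(step=40, num_steps=128, num_done=39, world_size=64)
assert not should_preempt(step=20, num_steps=128, num_done=39, world_size=64)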
class DDPPOAgent(Agent):
    def __init__(self, config: Config):
        if "ObjectNav" in config.TASK_CONFIG.TASK.TYPE:
            OBJECT_CATEGORIES_NUM = 20
            spaces = {
                "objectgoal": Box(low=0,
                                  high=OBJECT_CATEGORIES_NUM,
                                  shape=(1,),
                                  dtype=np.int64),
                "compass": Box(low=-np.pi,
                               high=np.pi,
                               shape=(1,),
                               dtype=np.float64),
                "gps": Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(2,),
                    dtype=np.float32,
                ),
            }
        else:
            spaces = {
                "pointgoal": Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(2,),
                    dtype=np.float32,
                )
            }

        if config.INPUT_TYPE in ["depth", "rgbd"]:
            spaces["depth"] = Box(
                low=0,
                high=1,
                shape=(config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.HEIGHT,
                       config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH, 1),
                dtype=np.float32,
            )

        if config.INPUT_TYPE in ["rgb", "rgbd"]:
            spaces["rgb"] = Box(
                low=0,
                high=255,
                shape=(config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.HEIGHT,
                       config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.WIDTH, 3),
                dtype=np.uint8,
            )

        observation_spaces = Dict(spaces)
        action_space = Discrete(len(config.TASK_CONFIG.TASK.POSSIBLE_ACTIONS))

        self.device = torch.device("cuda:{}".format(config.TORCH_GPU_ID))
        self.hidden_size = config.RL.PPO.hidden_size

        random.seed(config.RANDOM_SEED)
        torch.random.manual_seed(config.RANDOM_SEED)
        torch.backends.cudnn.deterministic = True

        self.actor_critic = PointNavResNetPolicy(
            observation_space=observation_spaces,
            action_space=action_space,
            hidden_size=self.hidden_size,
            normalize_visual_inputs="rgb"
            if config.INPUT_TYPE in ["rgb", "rgbd"] else False,
        )
        self.actor_critic.to(self.device)

        if config.MODEL_PATH:
            ckpt = torch.load(config.MODEL_PATH, map_location=self.device)
            print(f"Checkpoint loaded: {config.MODEL_PATH}")
            # Filter only actor_critic weights
            self.actor_critic.load_state_dict({
                k.replace("actor_critic.", ""): v
                for k, v in ckpt["state_dict"].items()
                if "actor_critic" in k
            })
        else:
            habitat.logger.error("Model checkpoint wasn't loaded, evaluating "
                                 "a random model.")

        self.test_recurrent_hidden_states = None
        self.not_done_masks = None
        self.prev_actions = None

    def reset(self):
        self.test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            1,
            self.hidden_size,
            device=self.device,
        )
        self.not_done_masks = torch.zeros(1, 1, device=self.device)
        self.prev_actions = torch.zeros(1, 1, dtype=torch.long,
                                        device=self.device)

    def act(self, observations):
        batch = batch_obs([observations], device=self.device)

        with torch.no_grad():
            _, action, _, self.test_recurrent_hidden_states = self.actor_critic.act(
                batch,
                self.test_recurrent_hidden_states,
                self.prev_actions,
                self.not_done_masks,
                deterministic=False,
            )
            # Masks stay "not done" until reset() is called at the end of the episode
            self.not_done_masks.fill_(1.0)
            self.prev_actions.copy_(action)

        return action.item()
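# Hypothetical usage sketch, not part of the original file: one way an evaluation loop
# could drive DDPPOAgent. `env` is assumed to expose a gym-style reset()/step() API
# returning habitat-style observation dicts; only agent.reset() and agent.act() come
# from the class above.
def run_episode(agent, env):
    agent.reset()
    observations = env.reset()
    done = False
    info = {}
    while not done:
        action = agent.act(observations)
        observations, reward, done, info = env.step(action)
    return info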
class DDPPOAgent(Agent):
    def __init__(self, config: Config):
        if "ObjectNav" in config.TASK_CONFIG.TASK.TYPE:
            OBJECT_CATEGORIES_NUM = 20
            spaces = {
                "objectgoal": Box(low=0,
                                  high=OBJECT_CATEGORIES_NUM,
                                  shape=(1,),
                                  dtype=np.int64),
                "compass": Box(low=-np.pi,
                               high=np.pi,
                               shape=(1,),
                               dtype=np.float64),
                "gps": Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(2,),
                    dtype=np.float32,
                ),
            }
        else:
            spaces = {
                "pointgoal": Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(2,),
                    dtype=np.float32,
                )
            }

        if config.INPUT_TYPE in ["depth", "rgbd"]:
            spaces["depth"] = Box(
                low=0,
                high=1,
                shape=(config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.HEIGHT,
                       config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH, 1),
                dtype=np.float32,
            )

        if config.INPUT_TYPE in ["rgb", "rgbd"]:
            spaces["rgb"] = Box(
                low=0,
                high=255,
                shape=(config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.HEIGHT,
                       config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.WIDTH, 3),
                dtype=np.uint8,
            )

        observation_spaces = Dict(spaces)
        action_space = Discrete(len(config.TASK_CONFIG.TASK.POSSIBLE_ACTIONS))

        self.device = torch.device("cuda:{}".format(config.TORCH_GPU_ID))
        self.hidden_size = config.RL.PPO.hidden_size

        random.seed(config.RANDOM_SEED)
        np.random.seed(config.RANDOM_SEED)
        _seed_numba(config.RANDOM_SEED)
        torch.random.manual_seed(config.RANDOM_SEED)
        torch.backends.cudnn.deterministic = True

        policy_arguments = OrderedDict(
            observation_space=observation_spaces,
            action_space=action_space,
            hidden_size=self.hidden_size,
            rnn_type=config.RL.DDPPO.rnn_type,
            num_recurrent_layers=config.RL.DDPPO.num_recurrent_layers,
            backbone=config.RL.DDPPO.backbone,
            normalize_visual_inputs="rgb"
            if config.INPUT_TYPE in ["rgb", "rgbd"] else False,
            final_beta=None,
            start_beta=None,
            beta_decay_steps=None,
            decay_start_step=None,
            use_info_bot=False,
            use_odometry=False,
        )
        if "ObjectNav" not in config.TASK_CONFIG.TASK.TYPE:
            policy_arguments[
                "goal_sensor_uuid"] = config.TASK_CONFIG.TASK.GOAL_SENSOR_UUID

        self.actor_critic = PointNavResNetPolicy(**policy_arguments)
        self.actor_critic.to(self.device)
        self._encoder = self.actor_critic.net.visual_encoder

        if config.MODEL_PATH:
            ckpt = torch.load(config.MODEL_PATH, map_location=self.device)
            print(f"Checkpoint loaded: {config.MODEL_PATH}")
            # Filter only actor_critic weights
            self.actor_critic.load_state_dict({
                k.replace("actor_critic.", ""): v
                for k, v in ckpt["state_dict"].items()
                if "actor_critic" in k
            })
        else:
            habitat.logger.error("Model checkpoint wasn't loaded, evaluating "
                                 "a random model.")

        self.test_recurrent_hidden_states = None
        self.not_done_masks = None
        self.prev_actions = None
        self.final_action = False

    def convertPolarToCartesian(self, coords):
        rho = coords[0]
        theta = -coords[1]
        return np.array([rho * np.cos(theta), rho * np.sin(theta)],
                        dtype=np.float32)

    def convertMaxDepth(self, obs):
        # min_depth = 0.1
        # max_depth = 5
        # obs = obs * (10 - 0.1) + 0.1
        # if isinstance(obs, np.ndarray):
        #     obs = np.clip(obs, min_depth, max_depth)
        # else:
        #     obs = obs.clamp(min_depth, max_depth)
        # obs = (obs - min_depth) / (max_depth - min_depth)
        return obs

    def reset(self):
        self.test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            1,
            self.hidden_size,
            device=self.device)
        self.not_done_masks = torch.zeros(1, 1, device=self.device)
        self.prev_actions = torch.zeros(1, 1, dtype=torch.long,
                                        device=self.device)
        self.prev_visual_features = None
        self.final_action = False

    def act(self, observations):
        observations["pointgoal"] = self.convertPolarToCartesian(
            observations["pointgoal"])
        observations["depth"] = self.convertMaxDepth(observations["depth"])

        batch = batch_obs([observations], device=self.device)
        batch["visual_features"] = self._encoder(batch)
        if self.prev_visual_features is None:
            batch["prev_visual_features"] = torch.zeros_like(
                batch["visual_features"])
        else:
            batch["prev_visual_features"] = self.prev_visual_features

        with torch.no_grad():
            step_batch = batch
            _, action, _, self.test_recurrent_hidden_states = self.actor_critic.act(
                batch,
                None,
                self.test_recurrent_hidden_states,
                self.prev_actions,
                self.not_done_masks,
                deterministic=False,
            )
            # Masks stay "not done" until reset() is called at the end of the episode
            self.not_done_masks.fill_(1.0)
            self.prev_actions.copy_(action)

        self.prev_visual_features = step_batch["visual_features"]

        # if self.final_action:
        #     return 0
        # if action.item() == 0:
        #     self.final_action = True
        #     return 1

        return action.item()
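# Illustrative check (not part of the original agent) of convertPolarToCartesian above:
# theta is negated before the trig, so a goal at rho = 2.0, theta = pi / 2 maps to
# approximately (x, y) = (0.0, -2.0).
import numpy as np

rho, theta = 2.0, np.pi / 2
goal = np.array([rho * np.cos(-theta), rho * np.sin(-theta)], dtype=np.float32)
assert np.allclose(goal, [0.0, -2.0], atol=1e-6)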