def train(self) -> None:
    r"""Main method for training DAgger.

    Steps, in order:

    1. Create the LMDB feature directory and the checkpoint folder.
    2. If ``DAGGER.PRELOAD_LMDB_FEATURES`` is set, verify the feature
       database can be opened read-only; otherwise drop the contents of
       the database so the run starts from an empty one.
    3. Copy the dataset split into the NDTW/SDTW task configs and, when
       ``DAGGER.P == 1.0``, set ``MAX_SCENE_REPEAT_STEPS`` to ``-1``.
    4. Construct the environments and the actor-critic agent.
    5. For each of ``DAGGER.ITERATIONS`` iterations: refresh the dataset
       through ``_update_dataset`` (skipped when preloading), then train
       for ``DAGGER.EPOCHS`` epochs, logging losses to the logger and to
       TensorBoard and saving one checkpoint per epoch.

    If ``_update_agent`` raises on a full batch, the batch is retried one
    sample at a time and the per-sample losses are summed. An exception
    raised during that retry propagates to the caller.

    Returns:
        None

    Raises:
        lmdb.Error: if preloading is requested and the feature database
            cannot be opened.
    """

    def _to_device(tensor):
        # Single place for the host-to-device transfer used on every batch.
        return tensor.to(device=self.device, non_blocking=True)

    os.makedirs(self.lmdb_features_dir, exist_ok=True)
    os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)

    if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
        try:
            # Probe only: the context manager closes the environment again
            # so no handle stays open while the dataset later reopens the
            # same path.
            with lmdb.open(self.lmdb_features_dir, readonly=True):
                pass
        except lmdb.Error:
            logger.error("Cannot open database for teacher forcing preload.")
            raise
    else:
        # Not preloading: empty the database before collecting new data.
        with lmdb.open(
            self.lmdb_features_dir,
            map_size=int(self.config.DAGGER.LMDB_MAP_SIZE),
        ) as lmdb_env, lmdb_env.begin(write=True) as txn:
            txn.drop(lmdb_env.open_db())

    split = self.config.TASK_CONFIG.DATASET.SPLIT
    self.config.defrost()
    self.config.TASK_CONFIG.TASK.NDTW.SPLIT = split
    self.config.TASK_CONFIG.TASK.SDTW.SPLIT = split

    # if doing teacher forcing, don't switch the scene until it is complete
    if self.config.DAGGER.P == 1.0:
        self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
    self.config.freeze()

    if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
        # When preloading features, it is quicker to load a single env
        # because only its observation space is needed.
        single_proc_config = self.config.clone()
        single_proc_config.defrost()
        single_proc_config.NUM_PROCESSES = 1
        single_proc_config.freeze()
        self.envs = construct_envs(
            single_proc_config, get_env_class(self.config.ENV_NAME)
        )
    else:
        self.envs = construct_envs(
            self.config, get_env_class(self.config.ENV_NAME)
        )

    self._setup_actor_critic_agent(
        self.config.MODEL,
        self.config.DAGGER.LOAD_FROM_CKPT,
        self.config.DAGGER.CKPT_TO_LOAD,
    )

    logger.info(
        "agent number of parameters: {}".format(
            sum(param.numel() for param in self.actor_critic.parameters())
        )
    )
    logger.info(
        "agent number of trainable parameters: {}".format(
            sum(
                p.numel()
                for p in self.actor_critic.parameters()
                if p.requires_grad
            )
        )
    )

    if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
        # The envs were only needed to set up the agent; release them.
        self.envs.close()
        del self.envs
        self.envs = None

    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs, purge_step=0
    ) as writer:
        for dagger_it in range(self.config.DAGGER.ITERATIONS):
            # Batch counter; restarts at 0 for every DAgger iteration.
            step_id = 0
            if not self.config.DAGGER.PRELOAD_LMDB_FEATURES:
                # Offset the iteration index by one when resuming from a
                # checkpoint.
                self._update_dataset(
                    dagger_it
                    + (1 if self.config.DAGGER.LOAD_FROM_CKPT else 0)
                )

            # Free cached GPU memory and garbage before training starts.
            if torch.cuda.is_available():
                with torch.cuda.device(self.device):
                    torch.cuda.empty_cache()
            gc.collect()

            dataset = IWTrajectoryDataset(
                self.lmdb_features_dir,
                self.config.DAGGER.USE_IW,
                inflection_weight_coef=self.config.MODEL.inflection_weight_coef,
                lmdb_map_size=self.config.DAGGER.LMDB_MAP_SIZE,
                batch_size=self.config.DAGGER.BATCH_SIZE,
            )

            AuxLosses.activate()
            for epoch in tqdm.trange(self.config.DAGGER.EPOCHS):
                diter = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=self.config.DAGGER.BATCH_SIZE,
                    shuffle=False,
                    collate_fn=collate_fn,
                    pin_memory=False,
                    drop_last=True,  # drop last batch if smaller
                    num_workers=0,
                )
                for batch in tqdm.tqdm(
                    diter,
                    total=dataset.length // dataset.batch_size,
                    leave=False,
                ):
                    (
                        observations_batch,
                        prev_actions_batch,
                        not_done_masks,
                        corrected_actions_batch,
                        weights_batch,
                    ) = batch
                    observations_batch = {
                        k: _to_device(v)
                        for k, v in observations_batch.items()
                    }
                    try:
                        loss, action_loss, aux_loss = self._update_agent(
                            observations_batch,
                            _to_device(prev_actions_batch),
                            _to_device(not_done_masks),
                            _to_device(corrected_actions_batch),
                            _to_device(weights_batch),
                        )
                    except Exception as err:
                        # `Exception` rather than a bare `except` so that
                        # KeyboardInterrupt/SystemExit still stop training
                        # instead of triggering the per-sample retry.
                        logger.info(
                            "ERROR: failed to update agent. Updating agent with batch size of 1."
                        )
                        # Record what went wrong so the cause is not lost.
                        logger.info(f"Batch update failed with: {err!r}")
                        loss, action_loss, aux_loss = 0, 0, 0
                        # Move the whole batch to host memory and feed the
                        # samples to the agent one at a time.
                        prev_actions_batch = prev_actions_batch.cpu()
                        not_done_masks = not_done_masks.cpu()
                        corrected_actions_batch = corrected_actions_batch.cpu()
                        weights_batch = weights_batch.cpu()
                        observations_batch = {
                            k: v.cpu() for k, v in observations_batch.items()
                        }
                        for i in range(not_done_masks.size(0)):
                            output = self._update_agent(
                                {
                                    k: _to_device(v[i])
                                    for k, v in observations_batch.items()
                                },
                                _to_device(prev_actions_batch[i]),
                                _to_device(not_done_masks[i]),
                                _to_device(corrected_actions_batch[i]),
                                _to_device(weights_batch[i]),
                            )
                            # Losses of the retried samples are summed.
                            loss += output[0]
                            action_loss += output[1]
                            aux_loss += output[2]

                    logger.info(f"train_loss: {loss}")
                    logger.info(f"train_action_loss: {action_loss}")
                    logger.info(f"train_aux_loss: {aux_loss}")
                    logger.info(f"Batches processed: {step_id}.")
                    logger.info(f"On DAgger iter {dagger_it}, Epoch {epoch}.")
                    writer.add_scalar(
                        f"train_loss_iter_{dagger_it}", loss, step_id
                    )
                    writer.add_scalar(
                        f"train_action_loss_iter_{dagger_it}",
                        action_loss,
                        step_id,
                    )
                    writer.add_scalar(
                        f"train_aux_loss_iter_{dagger_it}", aux_loss, step_id
                    )
                    step_id += 1

                # One checkpoint per epoch, numbered across DAgger
                # iterations.
                self.save_checkpoint(
                    f"ckpt.{dagger_it * self.config.DAGGER.EPOCHS + epoch}.pth"
                )
            AuxLosses.deactivate()
config.MODEL.TORCH_GPU_ID = config.TORCH_GPU_ID config.freeze() action_space = spaces.Discrete(4) policy = CMAPolicy(observation_space, action_space, config.MODEL).to(device) dummy_instruction = torch.randint(1, 4, size=(4 * 2, 8), device=device) dummy_instruction[:, 5:] = 0 dummy_instruction[0, 2:] = 0 obs = dict( rgb=torch.randn(4 * 2, 224, 224, 3, device=device), depth=torch.randn(4 * 2, 256, 256, 1, device=device), instruction=dummy_instruction, progress=torch.randn(4 * 2, 1, device=device), ) hidden_states = torch.randn( policy.net.state_encoder.num_recurrent_layers, 2, policy.net._hidden_size, device=device, ) prev_actions = torch.randint(0, 3, size=(4 * 2, 1), device=device) masks = torch.ones(4 * 2, 1, device=device) AuxLosses.activate() policy.evaluate_actions(obs, hidden_states, prev_actions, masks, prev_actions)