def forward(self, observations, rnn_hidden_states, prev_actions, masks): r""" instruction_embedding: [batch_size x INSTRUCTION_ENCODER.output_size] depth_embedding: [batch_size x DEPTH_ENCODER.output_size] rgb_embedding: [batch_size x RGB_ENCODER.output_size] """ ### instruction # instruction_embedding = self.instruction_encoder(observations) instruction_embedding = self._get_bert_embedding(observations) depth_embedding = self.depth_encoder(observations) rgb_embedding = self.rgb_encoder(observations) # print("depth_embedding: ", depth_embedding) # print("depth_embedding: ", depth_embedding.size()) if self.model_config.ablate_instruction: instruction_embedding = instruction_embedding * 0 if self.model_config.ablate_depth: depth_embedding = depth_embedding * 0 if self.model_config.ablate_rgb: rgb_embedding = rgb_embedding * 0 x = torch.cat([instruction_embedding, depth_embedding, rgb_embedding], dim=1) if self.model_config.SEQ2SEQ.use_prev_action: prev_actions_embedding = self.prev_action_embedding( ((prev_actions.float() + 1) * masks).long().view(-1) ) x = torch.cat([x, prev_actions_embedding], dim=1) x, rnn_hidden_states = self.state_encoder(x, rnn_hidden_states, masks) if self.model_config.PROGRESS_MONITOR.use and AuxLosses.is_active(): progress_hat = torch.tanh(self.progress_monitor(x)) progress_loss = F.mse_loss( progress_hat.squeeze(1), observations["progress"], reduction="none" ) AuxLosses.register_loss( "progress_monitor", progress_loss, self.model_config.PROGRESS_MONITOR.alpha, ) return x, rnn_hidden_states
def _update_agent(
    self, observations, prev_actions, not_done_masks, corrected_actions, weights
):
    T, N = corrected_actions.size()

    self.optimizer.zero_grad()

    recurrent_hidden_states = torch.zeros(
        self.actor_critic.net.num_recurrent_layers,
        N,
        self.config.MODEL.STATE_ENCODER.hidden_size,
        device=self.device,
    )

    AuxLosses.clear()

    distribution = self.actor_critic.build_distribution(
        observations, recurrent_hidden_states, prev_actions, not_done_masks
    )

    logits = distribution.logits
    logits = logits.view(T, N, -1)

    action_loss = F.cross_entropy(
        logits.permute(0, 2, 1), corrected_actions, reduction="none"
    )
    action_loss = ((weights * action_loss).sum(0) / weights.sum(0)).mean()

    aux_mask = (weights > 0).view(-1)
    aux_loss = AuxLosses.reduce(aux_mask)

    loss = action_loss + aux_loss
    loss.backward()

    self.optimizer.step()

    if isinstance(aux_loss, torch.Tensor):
        return loss.item(), action_loss.item(), aux_loss.item()
    return loss.item(), action_loss.item(), aux_loss
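
# The action loss above reduces a [T x N] per-step cross-entropy with per-step
# weights: each episode (column) is averaged under its own weight
# normalization before averaging over the batch. A small self-contained check
# of that reduction, with made-up shapes:
import torch
import torch.nn.functional as F

T, N, A = 5, 2, 4                        # steps, episodes, actions
logits = torch.randn(T, N, A)
corrected_actions = torch.randint(0, A, (T, N))
weights = torch.rand(T, N)               # e.g. inflection weights

per_step = F.cross_entropy(
    logits.permute(0, 2, 1), corrected_actions, reduction="none"
)                                        # [T x N]
action_loss = ((weights * per_step).sum(0) / weights.sum(0)).mean()
print(action_loss)                       # scalar loss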
def train(self) -> None: r"""Main method for training DAgger. Returns: None """ os.makedirs(self.lmdb_features_dir, exist_ok=True) os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True) if self.config.DAGGER.PRELOAD_LMDB_FEATURES: try: lmdb.open(self.lmdb_features_dir, readonly=True) except lmdb.Error as err: logger.error( "Cannot open database for teacher forcing preload.") raise err else: with lmdb.open(self.lmdb_features_dir, map_size=int(self.config.DAGGER.LMDB_MAP_SIZE) ) as lmdb_env, lmdb_env.begin(write=True) as txn: txn.drop(lmdb_env.open_db()) split = self.config.TASK_CONFIG.DATASET.SPLIT self.config.defrost() self.config.TASK_CONFIG.TASK.NDTW.SPLIT = split self.config.TASK_CONFIG.TASK.SDTW.SPLIT = split # if doing teacher forcing, don't switch the scene until it is complete if self.config.DAGGER.P == 1.0: self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = ( -1) self.config.freeze() if self.config.DAGGER.PRELOAD_LMDB_FEATURES: # when preloadeding features, its quicker to just load one env as we just # need the observation space from it. single_proc_config = self.config.clone() single_proc_config.defrost() single_proc_config.NUM_PROCESSES = 1 single_proc_config.freeze() self.envs = construct_envs(single_proc_config, get_env_class(self.config.ENV_NAME)) else: self.envs = construct_envs(self.config, get_env_class(self.config.ENV_NAME)) self._setup_actor_critic_agent( self.config.MODEL, self.config.DAGGER.LOAD_FROM_CKPT, self.config.DAGGER.CKPT_TO_LOAD, ) logger.info("agent number of parameters: {}".format( sum(param.numel() for param in self.actor_critic.parameters()))) logger.info("agent number of trainable parameters: {}".format( sum(p.numel() for p in self.actor_critic.parameters() if p.requires_grad))) if self.config.DAGGER.PRELOAD_LMDB_FEATURES: self.envs.close() del self.envs self.envs = None with TensorboardWriter(self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs, purge_step=0) as writer: for dagger_it in range(self.config.DAGGER.ITERATIONS): step_id = 0 if not self.config.DAGGER.PRELOAD_LMDB_FEATURES: self._update_dataset(dagger_it + ( 1 if self.config.DAGGER.LOAD_FROM_CKPT else 0)) if torch.cuda.is_available(): with torch.cuda.device(self.device): torch.cuda.empty_cache() gc.collect() dataset = IWTrajectoryDataset( self.lmdb_features_dir, self.config.DAGGER.USE_IW, inflection_weight_coef=self.config.MODEL. inflection_weight_coef, lmdb_map_size=self.config.DAGGER.LMDB_MAP_SIZE, batch_size=self.config.DAGGER.BATCH_SIZE, ) AuxLosses.activate() for epoch in tqdm.trange(self.config.DAGGER.EPOCHS): diter = torch.utils.data.DataLoader( dataset, batch_size=self.config.DAGGER.BATCH_SIZE, shuffle=False, collate_fn=collate_fn, pin_memory=False, drop_last=True, # drop last batch if smaller num_workers=0, ) for batch in tqdm.tqdm(diter, total=dataset.length // dataset.batch_size, leave=False): ( observations_batch, prev_actions_batch, not_done_masks, corrected_actions_batch, weights_batch, ) = batch observations_batch = { k: v.to(device=self.device, non_blocking=True) for k, v in observations_batch.items() } try: loss, action_loss, aux_loss = self._update_agent( observations_batch, prev_actions_batch.to(device=self.device, non_blocking=True), not_done_masks.to(device=self.device, non_blocking=True), corrected_actions_batch.to(device=self.device, non_blocking=True), weights_batch.to(device=self.device, non_blocking=True), ) except: logger.info( "ERROR: failed to update agent. Updating agent with batch size of 1." 
                        )
                        loss, action_loss, aux_loss = 0, 0, 0
                        prev_actions_batch = prev_actions_batch.cpu()
                        not_done_masks = not_done_masks.cpu()
                        corrected_actions_batch = corrected_actions_batch.cpu()
                        weights_batch = weights_batch.cpu()
                        observations_batch = {
                            k: v.cpu() for k, v in observations_batch.items()
                        }

                        for i in range(not_done_masks.size(0)):
                            output = self._update_agent(
                                {
                                    k: v[i].to(
                                        device=self.device, non_blocking=True
                                    )
                                    for k, v in observations_batch.items()
                                },
                                prev_actions_batch[i].to(
                                    device=self.device, non_blocking=True
                                ),
                                not_done_masks[i].to(
                                    device=self.device, non_blocking=True
                                ),
                                corrected_actions_batch[i].to(
                                    device=self.device, non_blocking=True
                                ),
                                weights_batch[i].to(
                                    device=self.device, non_blocking=True
                                ),
                            )
                            loss += output[0]
                            action_loss += output[1]
                            aux_loss += output[2]

                    logger.info(f"train_loss: {loss}")
                    logger.info(f"train_action_loss: {action_loss}")
                    logger.info(f"train_aux_loss: {aux_loss}")
                    logger.info(f"Batches processed: {step_id}.")
                    logger.info(f"On DAgger iter {dagger_it}, Epoch {epoch}.")
                    writer.add_scalar(
                        f"train_loss_iter_{dagger_it}", loss, step_id
                    )
                    writer.add_scalar(
                        f"train_action_loss_iter_{dagger_it}",
                        action_loss,
                        step_id,
                    )
                    writer.add_scalar(
                        f"train_aux_loss_iter_{dagger_it}", aux_loss, step_id
                    )
                    step_id += 1

                self.save_checkpoint(
                    f"ckpt.{dagger_it * self.config.DAGGER.EPOCHS + epoch}.pth"
                )
            AuxLosses.deactivate()
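
# IWTrajectoryDataset is not shown in this excerpt. The USE_IW and
# inflection_weight_coef options suggest inflection weighting: timesteps
# where the oracle action changes ("inflections") are up-weighted so rare
# turning decisions are not drowned out by long runs of FORWARD. A sketch of
# that weighting under those assumptions; the real dataset code may differ.
import torch


def inflection_weights(actions: torch.Tensor, coef: float) -> torch.Tensor:
    """actions: [T] oracle actions for one episode -> [T] per-step weights."""
    weights = torch.ones_like(actions, dtype=torch.float)
    # A step is an inflection when its action differs from the previous one.
    weights[1:] = (actions[1:] != actions[:-1]).float() * (coef - 1) + 1
    return weights


# e.g. actions [F, F, F, L, F]: the L step and the step after it get coef.
print(inflection_weights(torch.tensor([1, 1, 1, 2, 1]), coef=3.2))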
def forward(self, observations, rnn_hidden_states, prev_actions, masks): r""" instruction_embedding: [batch_size x INSTRUCTION_ENCODER.output_size] depth_embedding: [batch_size x DEPTH_ENCODER.output_size] rgb_embedding: [batch_size x RGB_ENCODER.output_size] """ instruction_embedding = self.instruction_encoder(observations) depth_embedding = self.depth_encoder(observations) depth_embedding = torch.flatten(depth_embedding, 2) rgb_embedding = self.rgb_encoder(observations) rgb_embedding = torch.flatten(rgb_embedding, 2) prev_actions = self.prev_action_embedding( ((prev_actions.float() + 1) * masks).long().view(-1) ) if self.model_config.ablate_instruction: instruction_embedding = instruction_embedding * 0 if self.model_config.ablate_depth: depth_embedding = depth_embedding * 0 if self.model_config.ablate_rgb: rgb_embedding = rgb_embedding * 0 if self.rcm_state_encoder: ( state, rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers], ) = self.state_encoder( rgb_embedding, depth_embedding, prev_actions, rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers], masks, ) else: rgb_in = self.rgb_linear(rgb_embedding) depth_in = self.depth_linear(depth_embedding) state_in = torch.cat([rgb_in, depth_in, prev_actions], dim=1) ( state, rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers], ) = self.state_encoder( state_in, rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers], masks, ) text_state_q = self.state_q(state) text_state_k = self.text_k(instruction_embedding) text_mask = (instruction_embedding == 0.0).all(dim=1) text_embedding = self._attn( text_state_q, text_state_k, instruction_embedding, text_mask ) rgb_k, rgb_v = torch.split( self.rgb_kv(rgb_embedding), self._hidden_size // 2, dim=1 ) depth_k, depth_v = torch.split( self.depth_kv(depth_embedding), self._hidden_size // 2, dim=1 ) text_q = self.text_q(text_embedding) rgb_embedding = self._attn(text_q, rgb_k, rgb_v) depth_embedding = self._attn(text_q, depth_k, depth_v) x = torch.cat( [state, text_embedding, rgb_embedding, depth_embedding, prev_actions], dim=1 ) x = self.second_state_compress(x) ( x, rnn_hidden_states[self.state_encoder.num_recurrent_layers :], ) = self.second_state_encoder( x, rnn_hidden_states[self.state_encoder.num_recurrent_layers :], masks ) if self.model_config.PROGRESS_MONITOR.use and AuxLosses.is_active(): progress_hat = torch.tanh(self.progress_monitor(x)) progress_loss = F.mse_loss( progress_hat.squeeze(1), observations["progress"], reduction="none" ) AuxLosses.register_loss( "progress_monitor", progress_loss, self.model_config.PROGRESS_MONITOR.alpha, ) return x, rnn_hidden_states
config.MODEL.TORCH_GPU_ID = config.TORCH_GPU_ID
config.freeze()

action_space = spaces.Discrete(4)

policy = CMAPolicy(observation_space, action_space, config.MODEL).to(device)

dummy_instruction = torch.randint(1, 4, size=(4 * 2, 8), device=device)
dummy_instruction[:, 5:] = 0
dummy_instruction[0, 2:] = 0

obs = dict(
    rgb=torch.randn(4 * 2, 224, 224, 3, device=device),
    depth=torch.randn(4 * 2, 256, 256, 1, device=device),
    instruction=dummy_instruction,
    progress=torch.randn(4 * 2, 1, device=device),
)
hidden_states = torch.randn(
    policy.net.state_encoder.num_recurrent_layers,
    2,
    policy.net._hidden_size,
    device=device,
)
prev_actions = torch.randint(0, 3, size=(4 * 2, 1), device=device)
masks = torch.ones(4 * 2, 1, device=device)

AuxLosses.activate()
policy.evaluate_actions(obs, hidden_states, prev_actions, masks, prev_actions)
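
# A hedged follow-up to the smoke test above: with AuxLosses active, the
# progress-monitor loss registered in forward() can be reduced and sanity
# checked. This assumes evaluate_actions processed a batch of 4 * 2 frames
# and that AuxLosses.reduce takes a per-frame mask, as in _update_agent.
aux_mask = torch.ones(4 * 2, device=device, dtype=torch.bool)
aux_loss = AuxLosses.reduce(aux_mask)
assert torch.isfinite(aux_loss).all(), "aux losses should be finite"
AuxLosses.deactivate()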