def learn(
    self, model: EmmentalModel, dataloaders: List[EmmentalDataLoader]
) -> None:
    r"""The learning procedure of emmental MTL.

    Args:
      model(EmmentalModel): The emmental model that needs to learn.
      dataloaders(List[EmmentalDataLoader]): A list of dataloaders used to
        learn the model.
    """
    # Generate the list of dataloaders for the learning process
    train_split = Meta.config["learner_config"]["train_split"]
    if isinstance(train_split, str):
        train_split = [train_split]

    train_dataloaders = [
        dataloader for dataloader in dataloaders if dataloader.split in train_split
    ]

    if not train_dataloaders:
        raise ValueError(
            f"Cannot find the specified train_split "
            f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
        )

    # Set up task_scheduler
    self._set_task_scheduler()

    # Calculate the total number of batches per epoch
    self.n_batches_per_epoch = self.task_scheduler.get_num_batches(train_dataloaders)

    # Set up logging manager
    self._set_logging_manager()
    # Set up optimizer
    self._set_optimizer(model)
    # Set up lr_scheduler
    self._set_lr_scheduler(model)

    # Set to training mode
    model.train()

    if Meta.config["meta_config"]["verbose"]:
        logger.info("Start learning...")

    self.metrics: Dict[str, float] = dict()
    self._reset_losses()

    for epoch_num in range(Meta.config["learner_config"]["n_epochs"]):
        batches = tqdm(
            enumerate(self.task_scheduler.get_batches(train_dataloaders, model)),
            total=self.n_batches_per_epoch,
            disable=(not Meta.config["meta_config"]["verbose"]),
            desc=f"Epoch {epoch_num}:",
        )
        for batch_num, batch in batches:
            # Convert a single batch into a batch list
            if not isinstance(batch, list):
                batch = [batch]

            total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
            batch_size = 0

            # Set gradients of all model parameters to zero
            self.optimizer.zero_grad()

            for uids, X_dict, Y_dict, task_to_label_dict, data_name, split in batch:
                batch_size += len(next(iter(Y_dict.values())))

                # Perform forward pass and calculate the loss and count
                uid_dict, loss_dict, prob_dict, gold_dict = model(
                    uids, X_dict, Y_dict, task_to_label_dict
                )

                # Update running loss and count
                for task_name in uid_dict.keys():
                    identifier = f"{task_name}/{data_name}/{split}"
                    self.running_uids[identifier].extend(uid_dict[task_name])
                    self.running_losses[identifier] += (
                        loss_dict[task_name].item() * len(uid_dict[task_name])
                        if len(loss_dict[task_name].size()) == 0
                        else torch.sum(loss_dict[task_name]).item()
                    )
                    self.running_probs[identifier].extend(prob_dict[task_name])
                    self.running_golds[identifier].extend(gold_dict[task_name])

                # Skip the backward pass if no loss is calculated
                if not loss_dict:
                    continue

                # Calculate the average loss
                loss = sum(
                    [
                        model.weights[task_name] * task_loss
                        if len(task_loss.size()) == 0
                        else torch.mean(model.weights[task_name] * task_loss)
                        for task_name, task_loss in loss_dict.items()
                    ]
                )

                # Perform backward pass to calculate gradients
                loss.backward()  # type: ignore

            # Clip gradient norm
            if Meta.config["learner_config"]["optimizer_config"]["grad_clip"]:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    Meta.config["learner_config"]["optimizer_config"]["grad_clip"],
                )

            # Update the parameters
            self.optimizer.step()

            self.metrics.update(self._logging(model, dataloaders, batch_size))

            batches.set_postfix(self.metrics)

            # Update lr using lr scheduler
            self._update_lr_scheduler(model, total_batch_num, self.metrics)

    model = self.logging_manager.close(model)
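# The running-loss update above handles two loss shapes: a 0-dim tensor (the task's
# loss has already been reduced over the batch) and a per-example loss tensor. A
# minimal standalone sketch of just that branch, with made-up values; it is not part
# of the learner itself:
import torch


def running_loss_increment(task_loss: torch.Tensor, n_examples: int) -> float:
    # 0-dim loss: scale the mean back up to a batch total; otherwise sum directly
    if len(task_loss.size()) == 0:
        return task_loss.item() * n_examples
    return torch.sum(task_loss).item()


assert running_loss_increment(torch.tensor(0.5), 4) == 2.0
assert running_loss_increment(torch.tensor([0.5, 0.75, 0.25]), 3) == 1.5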
def learn(
    self, model: EmmentalModel, dataloaders: List[EmmentalDataLoader]
) -> None:
    """Learning procedure of emmental MTL.

    Args:
      model: The emmental model that needs to learn.
      dataloaders: A list of dataloaders used to learn the model.
    """
    start_time = time.time()

    # Generate the list of dataloaders for the learning process
    train_split = Meta.config["learner_config"]["train_split"]
    if isinstance(train_split, str):
        train_split = [train_split]

    train_dataloaders = [
        dataloader for dataloader in dataloaders if dataloader.split in train_split
    ]

    if not train_dataloaders:
        raise ValueError(
            f"Cannot find the specified train_split "
            f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
        )

    # Set up task_scheduler
    self._set_task_scheduler()

    # Calculate the total number of batches per epoch
    self.n_batches_per_epoch: int = self.task_scheduler.get_num_batches(
        train_dataloaders
    )

    if self.n_batches_per_epoch == 0:
        logger.info("No batches in training dataloaders, exiting...")
        return

    # Set up learning counter
    self._set_learning_counter()
    # Set up logging manager
    self._set_logging_manager()
    # Set up wandb watch model
    if (
        Meta.config["logging_config"]["writer_config"]["writer"] == "wandb"
        and Meta.config["logging_config"]["writer_config"]["wandb_watch_model"]
    ):
        if Meta.config["logging_config"]["writer_config"]["wandb_model_watch_freq"]:
            wandb.watch(
                model,
                log_freq=Meta.config["logging_config"]["writer_config"][
                    "wandb_model_watch_freq"
                ],
            )
        else:
            wandb.watch(model)
    # Set up optimizer
    self._set_optimizer(model)
    # Set up lr_scheduler
    self._set_lr_scheduler(model)

    if Meta.config["learner_config"]["fp16"]:
        try:
            from apex import amp  # type: ignore
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to "
                "use fp16 training."
            )
        logger.info(
            f"Training model with 16-bit (mixed) precision "
            f"and {Meta.config['learner_config']['fp16_opt_level']} opt level."
        )
        model, self.optimizer = amp.initialize(
            model,
            self.optimizer,
            opt_level=Meta.config["learner_config"]["fp16_opt_level"],
        )

    # Multi-gpu training (after apex fp16 initialization)
    if (
        Meta.config["learner_config"]["local_rank"] == -1
        and Meta.config["model_config"]["dataparallel"]
    ):
        model._to_dataparallel()

    # Distributed training (after apex fp16 initialization)
    if Meta.config["learner_config"]["local_rank"] != -1:
        model._to_distributed_dataparallel()

    # Set to training mode
    model.train()

    if Meta.config["meta_config"]["verbose"]:
        logger.info("Start learning...")

    self.metrics: Dict[str, float] = dict()
    self._reset_losses()

    # Set gradients of all model parameters to zero
    self.optimizer.zero_grad()

    batch_iterator = self.task_scheduler.get_batches(train_dataloaders, model)

    for epoch_num in range(self.start_epoch, self.end_epoch):
        for train_dataloader in train_dataloaders:
            # Set epoch for distributed sampler
            if isinstance(train_dataloader, DataLoader) and isinstance(
                train_dataloader.sampler, DistributedSampler
            ):
                train_dataloader.sampler.set_epoch(epoch_num)

        step_pbar = tqdm(
            range(self.start_step, self.end_step),
            desc=f"Step {self.start_step + 1}/{self.end_step}"
            if self.use_step_base_counter
            else f"Epoch {epoch_num + 1}/{self.end_epoch}",
            disable=not Meta.config["meta_config"]["verbose"]
            or Meta.config["learner_config"]["local_rank"] not in [-1, 0],
        )

        for step_num in step_pbar:
            if self.use_step_base_counter:
                step_pbar.set_description(
                    f"Step {step_num + 1}/{self.total_steps}"
                )
                step_pbar.refresh()

            try:
                batch = next(batch_iterator)
            except StopIteration:
                batch_iterator = self.task_scheduler.get_batches(
                    train_dataloaders, model
                )
                batch = next(batch_iterator)

            # Check whether to skip the current batch
            if epoch_num < self.start_train_epoch or (
                epoch_num == self.start_train_epoch
                and step_num < self.start_train_step
            ):
                continue

            # Convert a single batch into a batch list
            if not isinstance(batch, list):
                batch = [batch]

            total_step_num = epoch_num * self.n_batches_per_epoch + step_num
            batch_size = 0

            for _batch in batch:
                batch_size += len(_batch.uids)

                # Perform forward pass and calculate the loss and count
                uid_dict, loss_dict, prob_dict, gold_dict = model(
                    _batch.uids,
                    _batch.X_dict,
                    _batch.Y_dict,
                    _batch.task_to_label_dict,
                    return_probs=Meta.config["learner_config"]["online_eval"],
                    return_action_outputs=False,
                )

                # Update running loss and count
                for task_name in uid_dict.keys():
                    identifier = f"{task_name}/{_batch.data_name}/{_batch.split}"
                    self.running_uids[identifier].extend(uid_dict[task_name])
                    self.running_losses[identifier] += (
                        loss_dict[task_name].item() * len(uid_dict[task_name])
                        if len(loss_dict[task_name].size()) == 0
                        else torch.sum(loss_dict[task_name]).item()
                    ) * model.task_weights[task_name]
                    if (
                        Meta.config["learner_config"]["online_eval"]
                        and prob_dict
                        and gold_dict
                    ):
                        self.running_probs[identifier].extend(prob_dict[task_name])
                        self.running_golds[identifier].extend(gold_dict[task_name])

                # Calculate the average loss
                loss = sum(
                    [
                        model.task_weights[task_name] * task_loss
                        if len(task_loss.size()) == 0
                        else torch.mean(model.task_weights[task_name] * task_loss)
                        for task_name, task_loss in loss_dict.items()
                    ]
                )

                # Perform backward pass to calculate gradients
                if Meta.config["learner_config"]["fp16"]:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()  # type: ignore

            if (total_step_num + 1) % Meta.config["learner_config"][
                "optimizer_config"
            ]["gradient_accumulation_steps"] == 0 or (
                step_num + 1 == self.end_step and epoch_num + 1 == self.end_epoch
            ):
                # Clip gradient norm
                if Meta.config["learner_config"]["optimizer_config"]["grad_clip"]:
                    if Meta.config["learner_config"]["fp16"]:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            Meta.config["learner_config"]["optimizer_config"][
                                "grad_clip"
                            ],
                        )
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(),
                            Meta.config["learner_config"]["optimizer_config"][
                                "grad_clip"
                            ],
                        )

                # Update the parameters
                self.optimizer.step()

                # Set gradients of all model parameters to zero
                self.optimizer.zero_grad()

            if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
                self.metrics.update(self._logging(model, dataloaders, batch_size))
                step_pbar.set_postfix(self.metrics)

            # Update lr using lr scheduler
            self._update_lr_scheduler(model, total_step_num, self.metrics)

        step_pbar.close()

    if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
        model = self.logging_manager.close(model)

    logger.info(f"Total learning time: {time.time() - start_time} seconds.")
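# A minimal sketch of how this learn() loop is typically driven. emmental.init,
# Meta.update_config, EmmentalModel, and EmmentalLearner are the standard Emmental
# entry points this method assumes; the task setup and dataloaders below are
# placeholders, and the config values are illustrative only.
import emmental
from emmental.learner import EmmentalLearner
from emmental.meta import Meta
from emmental.model import EmmentalModel

emmental.init("logs")  # populates Meta.config with defaults
Meta.update_config(
    config={
        "learner_config": {
            "n_epochs": 3,
            "optimizer_config": {"optimizer": "adam", "lr": 1e-4},
        }
    }
)

model = EmmentalModel(name="mtl_model")
# model.add_task(task)  # add one EmmentalTask per task before training

learner = EmmentalLearner()
# dataloaders: List[EmmentalDataLoader] whose .split matches
# learner_config["train_split"] (default "train")
# learner.learn(model, dataloaders)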
def _logging(
    self,
    model: EmmentalModel,
    dataloaders: List[EmmentalDataLoader],
    batch_size: int,
) -> Dict[str, float]:
    r"""Check whether it is time to evaluate or checkpoint, and do so if needed.

    Args:
      model(EmmentalModel): The model to log.
      dataloaders(List[EmmentalDataLoader]): The data to evaluate.
      batch_size(int): Batch size.

    Returns:
      dict: The score dict.
    """
    # Switch to eval mode for evaluation
    model.eval()

    metric_dict = dict()

    self.logging_manager.update(batch_size)

    # Log the loss and lr
    metric_dict.update(self._aggregate_running_metrics(model))

    # Evaluate the model and log the metric
    trigger_evaluation = self.logging_manager.trigger_evaluation()
    if trigger_evaluation:
        # Log task specific metric
        metric_dict.update(
            self._evaluate(
                model, dataloaders, Meta.config["learner_config"]["valid_split"]
            )
        )
        self.logging_manager.write_log(metric_dict)
        self._reset_losses()

    # Log metric dict every trigger evaluation time or full epoch
    if Meta.config["meta_config"]["verbose"] and (
        trigger_evaluation
        or self.logging_manager.epoch_total == int(self.logging_manager.epoch_total)
    ):
        logger.info(
            f"{self.logging_manager.counter_unit.capitalize()}: "
            f"{self.logging_manager.unit_total:.2f} {metric_dict}"
        )

    # Checkpoint the model
    if self.logging_manager.trigger_checkpointing():
        self.logging_manager.checkpoint_model(
            model, self.optimizer, self.lr_scheduler, metric_dict
        )
        self.logging_manager.write_log(metric_dict)
        self._reset_losses()

    # Switch back to train mode
    model.train()

    return metric_dict
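# The logging_manager consulted above decides when trigger_evaluation() and
# trigger_checkpointing() fire, based on Emmental's logging_config. A hedged sketch
# of the kind of config that drives those triggers; the key names mirror the config
# entries referenced in this code, but the exact values are illustrative only.
logging_config_example = {
    "logging_config": {
        "counter_unit": "epoch",  # unit advanced by logging_manager.update(batch_size)
        "evaluation_freq": 1.0,  # evaluate once per full epoch
        "checkpointing": True,
        "checkpointer_config": {
            "checkpoint_metric": {"model/train/all/loss": "min"},
        },
    }
}
# Meta.update_config(config=logging_config_example)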
def learn(
    self, model: EmmentalModel, dataloaders: List[EmmentalDataLoader]
) -> None:
    """Learning procedure of emmental MTL.

    Args:
      model: The emmental model that needs to learn.
      dataloaders: A list of dataloaders used to learn the model.
    """
    start_time = time.time()

    # Generate the list of dataloaders for the learning process
    train_split = Meta.config["learner_config"]["train_split"]
    if isinstance(train_split, str):
        train_split = [train_split]

    train_dataloaders = [
        dataloader for dataloader in dataloaders if dataloader.split in train_split
    ]

    if not train_dataloaders:
        raise ValueError(
            f"Cannot find the specified train_split "
            f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
        )

    # Set up task_scheduler
    self._set_task_scheduler()

    # Calculate the total number of batches per epoch
    self.n_batches_per_epoch = self.task_scheduler.get_num_batches(train_dataloaders)

    # Set up logging manager
    self._set_logging_manager()
    # Set up optimizer
    self._set_optimizer(model)
    # Set up lr_scheduler
    self._set_lr_scheduler(model)

    if Meta.config["learner_config"]["fp16"]:
        try:
            from apex import amp  # type: ignore
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to "
                "use fp16 training."
            )
        logger.info(
            f"Training model with 16-bit (mixed) precision "
            f"and {Meta.config['learner_config']['fp16_opt_level']} opt level."
        )
        model, self.optimizer = amp.initialize(
            model,
            self.optimizer,
            opt_level=Meta.config["learner_config"]["fp16_opt_level"],
        )

    # Multi-gpu training (after apex fp16 initialization)
    if (
        Meta.config["learner_config"]["local_rank"] == -1
        and Meta.config["model_config"]["dataparallel"]
    ):
        model._to_dataparallel()

    # Distributed training (after apex fp16 initialization)
    if Meta.config["learner_config"]["local_rank"] != -1:
        model._to_distributed_dataparallel()

    # Set to training mode
    model.train()

    if Meta.config["meta_config"]["verbose"]:
        logger.info("Start learning...")

    self.metrics: Dict[str, float] = dict()
    self._reset_losses()

    # Set gradients of all model parameters to zero
    self.optimizer.zero_grad()

    for epoch_num in range(Meta.config["learner_config"]["n_epochs"]):
        batches = tqdm(
            enumerate(self.task_scheduler.get_batches(train_dataloaders, model)),
            total=self.n_batches_per_epoch,
            disable=(
                not Meta.config["meta_config"]["verbose"]
                or Meta.config["learner_config"]["local_rank"] not in [-1, 0]
            ),
            desc=f"Epoch {epoch_num}:",
        )

        for batch_num, batch in batches:
            # Convert a single batch into a batch list
            if not isinstance(batch, list):
                batch = [batch]

            total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
            batch_size = 0

            for uids, X_dict, Y_dict, task_to_label_dict, data_name, split in batch:
                batch_size += len(next(iter(Y_dict.values())))

                # Perform forward pass and calculate the loss and count
                uid_dict, loss_dict, prob_dict, gold_dict = model(
                    uids, X_dict, Y_dict, task_to_label_dict
                )

                # Update running loss and count
                for task_name in uid_dict.keys():
                    identifier = f"{task_name}/{data_name}/{split}"
                    self.running_uids[identifier].extend(uid_dict[task_name])
                    self.running_losses[identifier] += (
                        loss_dict[task_name].item() * len(uid_dict[task_name])
                        if len(loss_dict[task_name].size()) == 0
                        else torch.sum(loss_dict[task_name]).item()
                    )
                    self.running_probs[identifier].extend(prob_dict[task_name])
                    self.running_golds[identifier].extend(gold_dict[task_name])

                # Skip the backward pass if no loss is calculated
                if not loss_dict:
                    continue

                # Calculate the average loss
                loss = sum(
                    [
                        model.weights[task_name] * task_loss
                        if len(task_loss.size()) == 0
                        else torch.mean(model.weights[task_name] * task_loss)
                        for task_name, task_loss in loss_dict.items()
                    ]
                )

                # Perform backward pass to calculate gradients
                if Meta.config["learner_config"]["fp16"]:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()  # type: ignore

            if (total_batch_num + 1) % Meta.config["learner_config"][
                "optimizer_config"
            ]["gradient_accumulation_steps"] == 0 or (
                batch_num + 1 == self.n_batches_per_epoch
                and epoch_num + 1 == Meta.config["learner_config"]["n_epochs"]
            ):
                # Clip gradient norm
                if Meta.config["learner_config"]["optimizer_config"]["grad_clip"]:
                    if Meta.config["learner_config"]["fp16"]:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            Meta.config["learner_config"]["optimizer_config"][
                                "grad_clip"
                            ],
                        )
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(),
                            Meta.config["learner_config"]["optimizer_config"][
                                "grad_clip"
                            ],
                        )

                # Update the parameters
                self.optimizer.step()

                # Set gradients of all model parameters to zero
                self.optimizer.zero_grad()

            if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
                self.metrics.update(self._logging(model, dataloaders, batch_size))
                batches.set_postfix(self.metrics)

            # Update lr using lr scheduler
            self._update_lr_scheduler(model, total_batch_num, self.metrics)

    if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
        model = self.logging_manager.close(model)

    logger.info(f"Total learning time: {time.time() - start_time} seconds.")
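# The loss aggregation above reduces each task's loss to a scalar and combines them
# with per-task weights. A standalone sketch of that reduction, with made-up task
# names, weights, and loss values (not part of the learner):
import torch

task_weights = {"task_a": 1.0, "task_b": 0.5}
loss_dict = {
    "task_a": torch.tensor(0.8),  # loss already reduced to a scalar
    "task_b": torch.tensor([0.2, 0.4, 0.6]),  # per-example losses
}

total_loss = sum(
    task_weights[name] * task_loss
    if len(task_loss.size()) == 0
    else torch.mean(task_weights[name] * task_loss)
    for name, task_loss in loss_dict.items()
)
# 1.0 * 0.8 + 0.5 * mean([0.2, 0.4, 0.6]) ≈ 1.0
print(total_loss)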