def evaluate(model: Model,
             instances: Iterable[Instance],
             task_name: str,
             data_iterator: DataIterator,
             cuda_device: int) -> Dict[str, Any]:
    """
    Evaluate a model on a particular task (usually after training).

    Parameters
    ----------
    model : ``allennlp.models.model.Model``, required
        The model to evaluate.
    instances : ``Iterable[Instance]``, required
        The (usually test) dataset on which to evaluate the model.
    task_name : ``str``, required
        The name of the task on which to evaluate the model.
    data_iterator : ``DataIterator``
        Iterator that batches and iterates over the dataset.
    cuda_device : ``int``
        CUDA device to use.

    Returns
    -------
    metrics : ``Dict[str, Any]``
        A dictionary containing the metrics on the evaluated dataset.
    """
    check_for_gpu(cuda_device)
    with torch.no_grad():
        model.eval()

        iterator = data_iterator(instances, num_epochs=1, shuffle=False)
        logger.info("Iterating over dataset")
        generator_tqdm = tqdm.tqdm(
            iterator, total=data_iterator.get_num_batches(instances))

        eval_loss = 0
        nb_batches = 0
        # These flags are constant for the whole evaluation run.
        train_stages = ["stm", "sd", "valid"]
        task_index = TASKS_NAME.index(task_name)
        train_stage = train_stages.index("stm")
        for tensor_batch in generator_tqdm:
            nb_batches += 1

            tensor_batch["task_index"] = torch.tensor(task_index)
            tensor_batch["reverse"] = torch.tensor(False)
            tensor_batch["for_training"] = torch.tensor(False)
            tensor_batch["train_stage"] = torch.tensor(train_stage)
            # Move the batch to the requested device instead of hard-coding GPU 0.
            tensor_batch = move_to_device(tensor_batch, cuda_device)

            eval_output_dict = model.forward(**tensor_batch)
            loss = eval_output_dict["loss"]
            eval_loss += loss.item()
            metrics = model.get_metrics(task_name=task_name)
            metrics["stm_loss"] = float(eval_loss / nb_batches)

            description = training_util.description_from_metrics(metrics)
            generator_tqdm.set_description(description, refresh=False)

        metrics = model.get_metrics(task_name=task_name, reset=True)
        metrics["stm_loss"] = float(eval_loss / nb_batches)
        return metrics
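# Usage sketch: this mirrors how `train_model` (further below) scores the best
# weights of one task on another task's test set; `best_model`, `pair_task`,
# and `multi_task_trainer` exist in that scope, nothing new is assumed here.
#
#     test_metrics = evaluate(
#         model=best_model,
#         task_name=pair_task._name,
#         instances=pair_task._test_data,
#         data_iterator=pair_task._data_iterator,
#         cuda_device=multi_task_trainer._cuda_device,
#     )
#     print(test_metrics["stm_loss"])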
def __init__(self,
             vocab: Vocabulary,
             text_field_embedder: TextFieldEmbedder,
             share_encoder: Seq2VecEncoder = None,
             private_encoder: Seq2VecEncoder = None,
             dropout: float = None,
             input_dropout: float = None,
             initializer: InitializerApplicator = InitializerApplicator(),
             regularizer: RegularizerApplicator = None) -> None:
    super(JointSentimentClassifier, self).__init__(vocab=vocab, regularizer=regularizer)

    self._text_field_embedder = text_field_embedder
    self._domain_embeddings = Embedding(len(TASKS_NAME), 50)
    if share_encoder is None and private_encoder is None:
        share_rnn = nn.LSTM(input_size=self._text_field_embedder.get_output_dim(),
                            hidden_size=150,
                            batch_first=True,
                            dropout=dropout,
                            bidirectional=True)
        share_encoder = PytorchSeq2SeqWrapper(share_rnn)
        private_rnn = nn.LSTM(input_size=self._text_field_embedder.get_output_dim(),
                              hidden_size=150,
                              batch_first=True,
                              dropout=dropout,
                              bidirectional=True)
        private_encoder = PytorchSeq2SeqWrapper(private_rnn)
        logger.info("Using LSTM as encoder")
        self._domain_embeddings = Embedding(len(TASKS_NAME),
                                            self._text_field_embedder.get_output_dim())
    self._share_encoder = share_encoder

    self._s_domain_discriminator = Discriminator(share_encoder.get_output_dim(), len(TASKS_NAME))

    self._p_domain_discriminator = Discriminator(private_encoder.get_output_dim(), len(TASKS_NAME))

    # TODO: individual valid discriminator per task
    self._valid_discriminator = Discriminator(self._domain_embeddings.get_output_dim(), 2)

    # Each task gets its own private encoder (deep copy) but shares the shared
    # encoder, the domain embeddings, and all three discriminators.
    for task in TASKS_NAME:
        tagger = SentimentClassifier(
            vocab=vocab,
            text_field_embedder=self._text_field_embedder,
            share_encoder=self._share_encoder,
            private_encoder=copy.deepcopy(private_encoder),
            domain_embeddings=self._domain_embeddings,
            s_domain_discriminator=self._s_domain_discriminator,
            p_domain_discriminator=self._p_domain_discriminator,
            valid_discriminator=self._valid_discriminator,
            dropout=dropout,
            input_dropout=input_dropout,
            label_smoothing=0.1,
            initializer=initializer,
        )
        self.add_module("_tagger_{}".format(task), tagger)

    logger.info("Multi-Task Learning Model has been instantiated.")
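# A standalone sketch (hypothetical dims, not part of the model) of why the
# discriminators above take `share_encoder.get_output_dim()` as input size:
# a bidirectional LSTM wrapped in PytorchSeq2SeqWrapper reports 2 * hidden_size.

import torch.nn as nn
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

demo_rnn = nn.LSTM(input_size=100, hidden_size=150, batch_first=True, bidirectional=True)
demo_encoder = PytorchSeq2SeqWrapper(demo_rnn)
print(demo_encoder.get_output_dim())  # 300 == 2 * 150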
def forward(self,
            task_index: torch.IntTensor,
            tokens: Dict[str, torch.LongTensor],
            epoch_trained: torch.IntTensor,
            valid_discriminator: Discriminator,
            reverse: torch.ByteTensor,
            for_training: torch.ByteTensor) -> Dict[str, torch.Tensor]:
    embedded_text_input = self._text_field_embedder(tokens)
    tokens_mask = util.get_text_field_mask(tokens)
    batch_size = get_batch_size(tokens)

    # TODO: the perturbation branch is currently disabled
    # (`np.random.rand() < -1` is always False); raise the threshold to re-enable it.
    if np.random.rand() < -1 and for_training.all():
        logger.info("Domain Embedding with Perturbation")
        domain_embeddings = self._domain_embeddings(
            torch.arange(0, len(TASKS_NAME)).cuda())
        domain_embedding = get_perturbation_domain_embedding(
            domain_embeddings, task_index, epoch_trained)
        # domain_embedding = FGSM(self._domain_embeddings, task_index, valid_discriminator)
        output_dict = {"valid": torch.tensor(0)}
    else:
        logger.info("Domain Embedding without Perturbation")
        domain_embedding = self._domain_embeddings(task_index)
        output_dict = {"valid": torch.tensor(1)}

    output_dict["domain_embedding"] = domain_embedding
    embedded_text_input = self._input_dropout(embedded_text_input)
    if self._with_domain_embedding:
        # Prepend the domain embedding as an extra "token" and extend the mask to cover it.
        domain_embedding = domain_embedding.expand(batch_size, 1, -1)
        embedded_text_input = torch.cat((domain_embedding, embedded_text_input), 1)
        tokens_mask = torch.cat([tokens_mask.new_ones(batch_size, 1), tokens_mask], 1)

    shared_encoded_text = self._shared_encoder(embedded_text_input, tokens_mask)
    # shared_encoded_text = self._seq2vec(shared_encoded_text, tokens_mask)
    shared_encoded_text = get_final_encoder_states(shared_encoded_text,
                                                   tokens_mask,
                                                   bidirectional=True)
    output_dict["share_embedding"] = shared_encoded_text

    private_encoded_text = self._private_encoder(embedded_text_input, tokens_mask)
    # private_encoded_text = self._seq2vec(private_encoded_text)
    private_encoded_text = get_final_encoder_states(private_encoded_text,
                                                    tokens_mask,
                                                    bidirectional=True)
    output_dict["private_embedding"] = private_encoded_text

    embedded_text = torch.cat([shared_encoded_text, private_encoded_text], -1)
    output_dict["embedded_text"] = embedded_text
    return output_dict
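# A minimal standalone sketch (hypothetical shapes) of the domain-token trick in
# the forward pass above: the task's domain embedding is prepended as an extra
# "token" and the mask is extended by one position so the encoders attend over it.

import torch

demo_batch, demo_len, demo_dim = 4, 12, 300
demo_text = torch.randn(demo_batch, demo_len, demo_dim)
demo_mask = torch.ones(demo_batch, demo_len)
demo_domain = torch.randn(demo_dim)

demo_domain = demo_domain.expand(demo_batch, 1, demo_dim)
demo_text = torch.cat((demo_domain, demo_text), 1)                        # (4, 13, 300)
demo_mask = torch.cat([demo_mask.new_ones(demo_batch, 1), demo_mask], 1)  # (4, 13)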
def _save_checkpoint(self, epoch: int, should_stop: bool, is_best: bool = False) -> None:
    ### Saving training state ###
    training_state = {
        "epoch": epoch,
        "should_stop": should_stop,
        "metric_infos": self._metric_infos,
        "task_infos": self._task_infos,
        "schedulers": {},
        "optimizers": {},
    }

    if self._optimizers is not None:
        for task_name, optimizers in self._optimizers.items():
            training_state["optimizers"][task_name] = {}
            for params_name, optimizer in optimizers.items():
                training_state["optimizers"][task_name][params_name] = optimizer.state_dict()
    if self._schedulers is not None:
        for task_name, scheduler in self._schedulers.items():
            training_state["schedulers"][task_name] = scheduler.lr_scheduler.state_dict()

    training_path = os.path.join(self._serialization_dir, "training_state.th")
    torch.save(training_state, training_path)
    logger.info("Checkpoint - Saved training state to {}", training_path)

    ### Saving model state ###
    model_path = os.path.join(self._serialization_dir, "model_state.th")
    model_state = self._model.state_dict()
    torch.save(model_state, model_path)
    logger.info("Checkpoint - Saved model state to {}", model_path)

    if is_best:
        logger.info("Checkpoint - Best validation performance so far for all tasks")
        logger.info("Checkpoint - Copying weights to '{}/best_all.th'.", self._serialization_dir)
        shutil.copyfile(model_path, os.path.join(self._serialization_dir, "best_all.th"))

    ### Saving the best model for each task ###
    for task_name, infos in self._metric_infos.items():
        best_epoch, _ = infos["best"]
        if best_epoch == epoch:
            logger.info("Checkpoint - Best validation performance so far for task {}", task_name)
            logger.info("Checkpoint - Copying weights to '{}/best_{}.th'.",
                        self._serialization_dir, task_name)
            shutil.copyfile(model_path,
                            os.path.join(self._serialization_dir, "best_{}.th".format(task_name)))
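# A sketch of the restore-side counterpart (the actual _restore_checkpoint is
# not shown in this file; the file names and dict layout follow _save_checkpoint
# above, everything else is an assumption):

def _restore_checkpoint_sketch(serialization_dir, model, optimizers):
    training_state = torch.load(os.path.join(serialization_dir, "training_state.th"))
    model.load_state_dict(torch.load(os.path.join(serialization_dir, "model_state.th")))
    for task_name, opts in optimizers.items():
        for params_name, optimizer in opts.items():
            optimizer.load_state_dict(training_state["optimizers"][task_name][params_name])
    # Resume from the epoch after the one that was saved.
    return training_state["epoch"] + 1, training_state["should_stop"]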
def train(self, recover: bool = False) -> Dict[str, Any]:
    # Training schedule:
    # 1. Train sentiment classifier & private classifier & domain embeddings => init G (~50 epochs)
    # 2. Fix share encoder (+ domain embeddings?); train share classifier (cls & real/fake)
    #    & others => train D
    # 3. Fix share classifier; train share encoder with the share classifier's input gradient
    #    reversed, minimizing its loss => train G
    training_start_time = time.time()

    if recover:
        try:
            n_epoch, should_stop = self._restore_checkpoint()
            logger.info("Loaded model from checkpoint. Starting at epoch {}", n_epoch)
        except RuntimeError:
            raise ConfigurationError(
                "Could not recover training from the checkpoint. Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?"
            )
    else:
        n_epoch, should_stop = 0, False

    ### Store all the necessary information and attributes about the tasks ###
    task_infos = {task._name: {} for task in self._task_list}
    for task_idx, task in enumerate(self._task_list):
        task_info = task_infos[task._name]

        # Store statistics on training and validation batches
        data_iterator = task._data_iterator
        n_tr_batches = data_iterator.get_num_batches(task._train_data)
        n_val_batches = data_iterator.get_num_batches(task._validation_data)
        task_info["n_tr_batches"] = n_tr_batches
        task_info["n_val_batches"] = n_val_batches

        # Create a counter for the number of batches trained over the whole
        # training run for this specific task
        task_info["total_n_batches_trained"] = 0

        task_info["last_log"] = time.time()  # Time of last logging
    self._task_infos = task_infos

    ### Bookkeeping the validation metrics ###
    metric_infos = {
        task._name: {
            "val_metric": task._val_metric,
            "hist": [],
            "is_out_of_patience": False,
            "min_lr_hit": False,
            "best": (-1, {}),
        }
        for task in self._task_list
    }
    self._metric_infos = metric_infos

    ### Write log ###
    total_n_tr_batches = 0  # The total number of training batches across all the datasets.
    for task_name, info in self._task_infos.items():
        total_n_tr_batches += info["n_tr_batches"]
        logger.info("Task {}:", task_name)
        logger.info("\t{} training batches", info["n_tr_batches"])
        logger.info("\t{} validation batches", info["n_val_batches"])

    ### Create the training generators/iterators tqdm ###
    self._tr_generators = {}
    for task in self._task_list:
        data_iterator = task._data_iterator
        tr_generator = data_iterator(task._train_data, num_epochs=None)
        self._tr_generators[task._name] = tr_generator

    ### Create the sampling probability distribution ###
    if self._sampling_method == "uniform":
        sampling_prob = [float(1 / self._n_tasks)] * self._n_tasks
    elif self._sampling_method == "proportional":
        sampling_prob = [float(info["n_tr_batches"] / total_n_tr_batches)
                         for info in self._task_infos.values()]

    ### Enable gradient clipping ###
    # Only if self._grad_clipping is specified
    self._enable_gradient_clipping()

    ### Setup is ready. Training of the model can begin ###
    logger.info("Setup ready. Beginning training/validation.")
Beginning training/validation.") avg_accuracies = [] best_accuracy = 0.0 ### Begin Training of the model ### while not should_stop: ### Log Infos: current epoch count and CPU/GPU usage ### logger.info("") logger.info("Epoch {}/{} - Begin", n_epoch, self._num_epochs - 1) logger.info(f"Peak CPU memory usage MB: {peak_memory_mb()}") for gpu, memory in gpu_memory_mb().items(): logger.info(f"GPU {gpu} memory usage MB: {memory}") # if n_epoch <= 10: # # init generator # all_tr_metrics = self._train_epoch(total_n_tr_batches, sampling_prob) # # train discriminator 3 epochs # # elif 10 < n_epoch < 20 or n_epoch % 2 == 0: # # all_tr_metrics = self._train_epoch(total_n_tr_batches, sampling_prob, train_D=True) # else: # train adversarial generator every 3 epoch all_tr_metrics = self._train_epoch(total_n_tr_batches, sampling_prob, reverse=True) all_val_metrics, avg_accuracy = self._validation(n_epoch) is_best = False if best_accuracy < avg_accuracy: best_accuracy = avg_accuracy logger.info("Best accuracy found --- {}", best_accuracy / self._n_tasks) is_best = True ### Print all training and validation metrics for this epoch ### logger.info("***** Epoch {}/{} Statistics *****", n_epoch, self._num_epochs - 1) for task in self._task_list: logger.info("Statistic: {}", task._name) logger.info( "\tTraining - {}: {:3d}", "Nb batches trained", self._task_infos[task._name]["n_batches_trained_this_epoch"], ) for metric_name, value in all_tr_metrics[task._name].items(): logger.info("\tTraining - {}: {:.3f}", metric_name, value) for metric_name, value in all_val_metrics[task._name].items(): logger.info("\tValidation - {}: {:.3f}", metric_name, value) logger.info("***** Average accuracy is {:.6f} *****", avg_accuracy / self._n_tasks) avg_accuracies.append(avg_accuracy / self._n_tasks) logger.info("**********") ### Check to see if should stop ### stop_tr, stop_val = True, True for task in self._task_list: # task_info = self._task_infos[tasks._name] if self._optimizers[task._name]['exclude_share_encoder'].param_groups[0]["lr"] < self._min_lr and \ self._optimizers[task._name]['exclude_share_discriminator'].param_groups[0][ "lr"] < self._min_lr: logger.info("Minimum lr hit on {}.", task._name) logger.info("Task {} vote to stop training.", task._name) metric_infos[task._name]["min_lr_hit"] = True stop_tr = stop_tr and self._metric_infos[task._name]["min_lr_hit"] stop_val = stop_val and self._metric_infos[task._name]["is_out_of_patience"] if stop_tr: should_stop = True logger.info("All tasks hit minimum lr. Stopping training.") if stop_val: should_stop = True logger.info("All metrics ran out of patience. Stopping training.") if n_epoch >= self._num_epochs - 1: should_stop = True logger.info("Maximum number of epoch hit. Stopping training.") self._save_checkpoint(n_epoch, should_stop, is_best) ### Update n_epoch ### # One epoch = doing N (forward + backward) pass where N is the total number of training batches. 
        n_epoch += 1
        self._epoch_trained = n_epoch

    logger.info("Max accuracy is {:.6f}", max(avg_accuracies))

    ### Summarize training at the end ###
    logger.info("***** Training is finished *****")
    logger.info("Stopped training after {} epochs", n_epoch)
    return_metrics = {}
    for task_name, task_info in self._task_infos.items():
        nb_epoch_trained = int(task_info["total_n_batches_trained"] / task_info["n_tr_batches"])
        logger.info(
            "Trained {} for {} batches ~= {} epochs",
            task_name,
            task_info["total_n_batches_trained"],
            nb_epoch_trained,
        )
        return_metrics[task_name] = {
            "best_epoch": self._metric_infos[task_name]["best"][0],
            "nb_epoch_trained": nb_epoch_trained,
            "best_epoch_val_metrics": self._metric_infos[task_name]["best"][1],
        }

    training_elapsed_time = time.time() - training_start_time
    return_metrics["training_duration"] = time.strftime("%d:%H:%M:%S",
                                                        time.gmtime(training_elapsed_time))
    return_metrics["nb_epoch_trained"] = n_epoch

    return return_metrics
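# Standalone sketch of the task-sampling scheme set up in train(): with
# "proportional" sampling, each training batch is drawn from a task with
# probability proportional to its number of training batches (counts below
# are hypothetical).

import numpy as np

demo_batches = {"apparel": 120, "books": 480, "dvd": 200}
demo_total = sum(demo_batches.values())
demo_prob = [n / demo_total for n in demo_batches.values()]
demo_names = list(demo_batches)
for _ in range(3):
    demo_idx = np.argmax(np.random.multinomial(1, demo_prob))
    print(demo_names[demo_idx])  # "books" roughly 60% of the time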
def _validation(self, n_epoch: int) -> Tuple[Dict[str, Any], float]:
    ### Begin validation of the model ###
    logger.info("Validation - Begin")
    all_val_metrics = {}

    self._model.eval()  # Set the model into evaluation mode

    avg_accuracy = 0.0

    for task_idx, task in enumerate(self._task_list):
        logger.info("Validation - Task {}/{}: {}", task_idx + 1, self._n_tasks, task._name)

        val_loss = 0.0
        n_batches_val_this_epoch_this_task = 0
        n_val_batches = self._task_infos[task._name]["n_val_batches"]
        scheduler = self._schedulers[task._name]

        # Create a tqdm generator for the current task's validation data
        data_iterator = task._data_iterator
        val_generator = data_iterator(task._validation_data, num_epochs=1, shuffle=False)
        val_generator_tqdm = tqdm.tqdm(val_generator, total=n_val_batches)

        # Iterate over each validation batch for this task
        for batch in val_generator_tqdm:
            n_batches_val_this_epoch_this_task += 1

            # Get the loss
            val_output_dict = self._forward(batch, task=task, for_training=False)
            loss = val_output_dict["stm_loss"]
            val_loss += loss.item()
            del loss

            # Get metrics for all progress so far, update tqdm, display description
            task_metrics = self._get_metrics(task=task)
            task_metrics["loss"] = float(val_loss / n_batches_val_this_epoch_this_task)
            description = training_util.description_from_metrics(task_metrics)
            val_generator_tqdm.set_description(description)

        # Get the task's validation metrics and store them in all_val_metrics
        task_metrics = self._get_metrics(task=task, reset=True)
        if task._name not in all_val_metrics:
            all_val_metrics[task._name] = {}
        for name, value in task_metrics.items():
            all_val_metrics[task._name][name] = value
        all_val_metrics[task._name]["loss"] = float(val_loss / n_batches_val_this_epoch_this_task)
        avg_accuracy += task_metrics["sentiment_acc"]

        # Tensorboard - Validation metrics for this epoch
        for metric_name, value in all_val_metrics[task._name].items():
            self._tensorboard.add_validation_scalar(
                name="task_" + task._name + "/" + metric_name, value=value
            )

        ### Perform a patience check and update the validation-metric history for this task ###
        this_epoch_val_metric = all_val_metrics[task._name][task._val_metric]
        metric_history = self._metric_infos[task._name]["hist"]
        metric_history.append(this_epoch_val_metric)
        is_best_so_far, out_of_patience = self._check_history(
            metric_history=metric_history,
            cur_score=this_epoch_val_metric,
            should_decrease=task._val_metric_decreases,
        )

        if is_best_so_far:
            logger.info("Best model found for {}.", task._name)
            self._metric_infos[task._name]["best"] = (n_epoch, all_val_metrics)
        if out_of_patience and not self._metric_infos[task._name]["is_out_of_patience"]:
            self._metric_infos[task._name]["is_out_of_patience"] = True
            logger.info("Task {} is out of patience and votes to stop the training.", task._name)

        # The LRScheduler API is agnostic to whether your schedule requires a validation metric -
        # if it doesn't, the validation metric passed here is ignored.
        scheduler.step(this_epoch_val_metric, n_epoch)

    logger.info("Validation - End")
    return all_val_metrics, avg_accuracy
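# A minimal sketch of what the patience check above presumably does
# (_check_history is defined elsewhere; `patience` and the tie-breaking on
# equal scores are assumptions):

def _check_history_sketch(metric_history, cur_score, should_decrease, patience=5):
    best = min(metric_history) if should_decrease else max(metric_history)
    is_best_so_far = cur_score == best
    # Out of patience once the best score is more than `patience` epochs old.
    out_of_patience = (len(metric_history) - 1) - metric_history.index(best) >= patience
    return is_best_so_far, out_of_patience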
def _train_epoch(self,
                 total_n_tr_batches: int,
                 sampling_prob: List,
                 reverse: bool = False,
                 train_D: bool = False) -> Dict[str, float]:
    self._model.train()  # Set the model to "train" mode.

    if reverse:
        logger.info("Training Generator - Begin")
    elif not train_D:
        logger.info("Training Init Generator - Begin")
    if train_D:
        logger.info("Training Discriminator - Begin")
    logger.info("reverse is {}, train_D is {}", reverse, train_D)

    ### Reset the training and trained-batches counters before a new training epoch ###
    for _, task_info in self._task_infos.items():
        task_info["tr_loss_cum"] = 0.0
        task_info["stm_loss"] = 0.0
        task_info["p_d_loss"] = 0.0
        task_info["s_d_loss"] = 0.0
        task_info["valid_loss"] = 0.0
        task_info["n_batches_trained_this_epoch"] = 0
    all_tr_metrics = {}  # TODO: complete this comment to make it clearer

    ### Start training epoch ###
    epoch_tqdm = tqdm.tqdm(range(total_n_tr_batches), total=total_n_tr_batches)
    histogram_parameters = set(self._model.get_parameters_for_histogram_tensorboard_logging())
    # Set on the first accumulated optimizer step; guards the Tensorboard logging below.
    batch_grad_norm = None
    for step, _ in enumerate(epoch_tqdm):
        # Sample a task according to the sampling probability distribution
        task_idx = np.argmax(np.random.multinomial(1, sampling_prob))
        task = self._task_list[task_idx]
        task_info = self._task_infos[task._name]

        ### One forward + backward pass ###
        # Fetch the next batch to train on
        batch = next(self._tr_generators[task._name])
        self._batch_num_total += 1
        task_info["n_batches_trained_this_epoch"] += 1

        # Load the optimizer
        if not train_D:
            optimizer = self._optimizers[task._name]["all_params"]
        else:
            optimizer = self._optimizers[task._name]["exclude_share_encoder"]

        # Get the loss for this batch
        output_dict = self._forward(tensor_batch=batch, task=task,
                                    for_training=True, reverse=reverse)
        # Earlier variants combining a second "fake" pass, kept for reference:
        # if reverse or train_D:
        #     output_dict_fake = self._forward(tensor_batch=batch, task=task,
        #                                      for_training=True, reverse=True)
        # loss = output_dict["stm_loss"]
        # if train_D:
        #     loss = (output_dict["stm_loss"] + output_dict["s_d_loss"] +
        #             output_dict_fake["stm_loss"] + output_dict_fake["s_d_loss"]) / 2.0
        # if reverse:
        #     # loss = (output_dict["stm_loss"] + output_dict["p_d_loss"] + 0.005 * output_dict["s_d_loss"] +
        #     #         output_dict_fake["stm_loss"] + output_dict_fake["p_d_loss"] +
        #     #         0.005 * output_dict_fake["s_d_loss"]) / 2.0
        #     loss = (output_dict["loss"] + output_dict_fake["loss"]) / 2.0
        loss = output_dict["loss"]
        if self._gradient_accumulation_steps > 1:
            loss /= self._gradient_accumulation_steps
        loss.backward()
        task_info["tr_loss_cum"] += loss.item()
        task_info["stm_loss"] += output_dict["stm_loss"].item()
        task_info["p_d_loss"] += output_dict["p_d_loss"].item()
        task_info["s_d_loss"] += output_dict["s_d_loss"].item()
        task_info["valid_loss"] += output_dict["valid_loss"].item()
        # if reverse or train_D:
        #     task_info["stm_loss"] += output_dict_fake["stm_loss"].item()
        #     task_info["stm_loss"] /= 2.0
        #     task_info["p_d_loss"] += output_dict_fake["p_d_loss"].item()
        #     task_info["p_d_loss"] /= 2.0
        #     task_info["s_d_loss"] += output_dict_fake["s_d_loss"].item()
        #     task_info["s_d_loss"] /= 2.0
        #     task_info["valid_loss"] += output_dict_fake["valid_loss"].item()
        #     task_info["valid_loss"] /= 2.0
        del loss

        if (step + 1) % self._gradient_accumulation_steps == 0:
            batch_grad_norm = self._rescale_gradients()
            if self._tensorboard.should_log_histograms_this_batch():
                # Log the per-parameter ratio of update norm to parameter norm
                param_updates = {name: param.detach().cpu().clone()
                                 for name, param in self._model.named_parameters()}
                optimizer.step()
                for name, param in self._model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
                    update_norm = torch.norm(param_updates[name].view(-1))
                    param_norm = torch.norm(param.view(-1)).cpu()
self._tensorboard.add_train_scalar("gradient_update/" + name, update_norm / (param_norm + 1e-7)) else: optimizer.step() optimizer.zero_grad() ### Get metrics for all progress so far, update tqdm, display description ### task_metrics = self._get_metrics(task=task) task_metrics["loss"] = float( task_info["tr_loss_cum"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001) ) task_metrics["stm_loss"] = float( task_info["stm_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001) ) task_metrics["p_d_loss"] = float( task_info["p_d_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001) ) task_metrics["s_d_loss"] = float( task_info["s_d_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001) ) task_metrics["valid_loss"] = float( task_info["valid_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001) ) description = training_util.description_from_metrics(task_metrics) epoch_tqdm.set_description(task._name + ", " + description) # Log parameter values to Tensorboard if self._tensorboard.should_log_this_batch(): self._tensorboard.log_parameter_and_gradient_statistics(self._model, batch_grad_norm) self._tensorboard.log_learning_rates(self._model, optimizer) self._tensorboard.log_metrics( {"epoch_metrics/" + task._name + "/" + k: v for k, v in task_metrics.items()}) if self._tensorboard.should_log_histograms_this_batch(): self._tensorboard.log_histograms(self._model, histogram_parameters) self._global_step += 1 ### Bookkeeping all the training metrics for all the tasks on the training epoch that just finished ### for task in self._task_list: task_info = self._task_infos[task._name] task_info["total_n_batches_trained"] += task_info["n_batches_trained_this_epoch"] task_info["last_log"] = time.time() task_metrics = self._get_metrics(task=task, reset=True) if task._name not in all_tr_metrics: all_tr_metrics[task._name] = {} for name, value in task_metrics.items(): all_tr_metrics[task._name][name] = value all_tr_metrics[task._name]["loss"] = float( task_info["tr_loss_cum"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01) ) all_tr_metrics[task._name]["stm_loss"] = float( task_info["stm_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01) ) all_tr_metrics[task._name]["p_d_loss"] = float( task_info["p_d_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01) ) all_tr_metrics[task._name]["s_d_loss"] = float( task_info["s_d_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01) ) all_tr_metrics[task._name]["valid_loss"] = float( task_info["valid_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01) ) # Tensorboard - Training metrics for this epoch for metric_name, value in all_tr_metrics[task._name].items(): self._tensorboard.add_train_scalar( name="task_" + task._name + "/" + metric_name, value=value ) logger.info("Train - End") return all_tr_metrics
def tasks_and_vocab_from_params(params: Params,
                                serialization_dir: str) -> Tuple[List[Task], Vocabulary]:
    """
    Load each of the tasks in the model from the ``params`` file and load the
    datasets associated with each of these tasks. Create the vocabulary from
    ``params`` using the concatenation of the ``datasets_for_vocab_creation``
    from each of the task-specific datasets.

    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an experiment.
    serialization_dir: ``str``
        Directory in which to save the model and its logs.

    Returns
    -------
    task_list: ``List[Task]``
        A list containing the tasks of the model to train.
    vocab: ``Vocabulary``
        The vocabulary fitted on the ``datasets_for_vocab_creation``.
    """
    ### Instantiate the different tasks ###
    task_list = []
    instances_for_vocab_creation = itertools.chain()
    datasets_for_vocab_creation = {}
    task_keys = TASKS_NAME
    data_dir = "data/mtl-dataset"

    for key in task_keys:
        logger.info("Creating {}", key)
        task_data_params = Params({
            "dataset_reader": {
                "type": "semantic_review"
            },
            "train_data_path": os.path.join(data_dir, key + ".task.train"),
            "test_data_path": os.path.join(data_dir, key + ".task.test"),
            "validation_data_path": os.path.join(data_dir, key + ".task.test"),
            "datasets_for_vocab_creation": ["train", "validation", "test"]
        })
        task_description = Params({
            "task_name": key,
            "validation_metric_name": "{}_stm_acc".format(key)
        })
        task = Task.from_params(params=task_description)
        task_list.append(task)

        task_instances_for_vocab, task_datasets_for_vocab = task.load_data_from_params(
            params=task_data_params)
        instances_for_vocab_creation = itertools.chain(
            instances_for_vocab_creation, task_instances_for_vocab)
        datasets_for_vocab_creation[task._name] = task_datasets_for_vocab

    ### Create and save the vocabulary ###
    for task_name, task_dataset_list in datasets_for_vocab_creation.items():
        logger.info("Creating a vocabulary using {} data from {}.",
                    ", ".join(task_dataset_list), task_name)

    logger.info("Fitting vocabulary from dataset")
    vocab = Vocabulary.from_params(params.pop("vocabulary", {}), instances_for_vocab_creation)

    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
    logger.info("Vocabulary saved to {}", os.path.join(serialization_dir, "vocabulary"))

    return task_list, vocab
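# Standalone sketch of the lazy chaining pattern above: instances from every
# task are concatenated without being materialized, and the vocabulary is
# fitted once over the whole chain (toy stand-ins for the Instance streams).

import itertools

demo_instances = itertools.chain()
for demo_task_instances in (["a", "b"], ["c"], ["d", "e"]):
    demo_instances = itertools.chain(demo_instances, demo_task_instances)
print(list(demo_instances))  # ['a', 'b', 'c', 'd', 'e']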
def train_model(multi_task_trainer: TransferMtlTrainer,
                recover: bool = False) -> Dict[str, Any]:
    """
    Launch the training of the multi-task model.

    Parameters
    ----------
    multi_task_trainer: ``MultiTaskTrainer``
        A trainer (similar to allennlp.training.trainer.Trainer) that can handle
        multi-task training.
    recover : ``bool``, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory. This is only intended for use when something actually crashed during the middle
        of a run. For continuing training a model on new data, see the ``fine-tune`` command.

    Returns
    -------
    metrics: ``Dict[str, Any]``
        The different metrics summarizing the training of the model.
        It includes the validation and test (if necessary) metrics.
    """
    ### Train the multi-task model ###
    metrics = multi_task_trainer.train(recover=recover)

    task_list = multi_task_trainer._task_list
    serialization_dir = multi_task_trainer._serialization_dir
    model = multi_task_trainer._model

    # Evaluate the model on test data, if necessary. In a multi-task learning framework,
    # the best validation metrics for one task are not necessarily obtained at the same epoch
    # for all the tasks (one epoch being equal to N forward+backward passes, where N is the
    # total number of batches in all the training sets). So we evaluate the best model for
    # each task (based on its validation metric) on all the other tasks that have a test set.
    avg_accuracies = []
    for task in task_list:
        if not task._evaluate_on_test:
            continue

        logger.info("Task {} will be evaluated using the best epoch weights.", task._name)
        assert (
            task._test_data is not None
        ), "Task {} wants to be evaluated on a test dataset but no test data was loaded.".format(
            task._name)

        logger.info("Loading the best epoch weights for task {}", task._name)
        best_model_state_path = os.path.join(serialization_dir, "best_{}.th".format(task._name))
        best_model_state = torch.load(best_model_state_path)
        best_model = model
        best_model.load_state_dict(state_dict=best_model_state)

        test_metric_dict = {}
        avg_accuracy = 0.0
        for pair_task in task_list:
            if not pair_task._evaluate_on_test:
                continue

            logger.info("Pair task {} is evaluated with the best model for {}",
                        pair_task._name, task._name)
            test_metric_dict[pair_task._name] = {}
            test_metrics = evaluate(
                model=best_model,
                task_name=pair_task._name,
                instances=pair_task._test_data,
                data_iterator=pair_task._data_iterator,
                cuda_device=multi_task_trainer._cuda_device,
            )

            for metric_name, value in test_metrics.items():
                test_metric_dict[pair_task._name][metric_name] = value
            avg_accuracy += test_metrics["{}_stm_acc".format(pair_task._name)]

        logger.info("Average accuracy of task {} is {}", task._name, avg_accuracy / len(task_list))
        avg_accuracies.append(avg_accuracy / len(task_list))
        metrics[task._name]["test"] = deepcopy(test_metric_dict)
        logger.info("Finished evaluation of task {}.", task._name)

    ### Dump validation and possibly test metrics ###
    metrics_json = json.dumps(metrics, indent=2)
    with open(os.path.join(serialization_dir, "metrics.json"), "w") as metrics_file:
        metrics_file.write(metrics_json)
    logger.info("Metrics: {}", metrics_json)
    logger.info("Average accuracy is {}", sum(avg_accuracies) / len(avg_accuracies))

    return metrics
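# Shape of the dumped metrics.json, pieced together from train() and the test
# loop above (task names and values are hypothetical):
#
# {
#   "apparel": {
#     "best_epoch": 7,
#     "nb_epoch_trained": 12,
#     "best_epoch_val_metrics": {...},
#     "test": {"apparel": {...}, "books": {...}}
#   },
#   ...
#   "training_duration": "00:03:41:12",
#   "nb_epoch_trained": 14
# }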
    tasks, vocab = tasks_and_vocab_from_params(params=params,
                                               serialization_dir=serialization_dir)

    ### Load the data iterators for each task ###
    tasks = create_and_set_iterators(params=params, task_list=tasks, vocab=vocab)

    ### Load regularizations ###
    regularizer = RegularizerApplicator.from_params(params.pop("regularizer", []))

    ### Create model ###
    model_params = params.pop("model")
    model = Model.from_params(vocab=vocab, params=model_params, regularizer=regularizer)

    ### Create the multi-task trainer ###
    multi_task_trainer_params = params.pop("multi_task_trainer")
    trainer = TransferMtlTrainer.from_params(
        model=model,
        task_list=tasks,
        serialization_dir=serialization_dir,
        params=multi_task_trainer_params)

    ### Launch training ###
    metrics = train_model(multi_task_trainer=trainer, recover=args.recover)
    if metrics is not None:
        logger.info("Training is finished! Let's have a drink. It's on the house!")
action="store_true", required=False, default=False, help="Whether or not evaluate using gold mentions in coreference", ) args = parser.parse_args() params = Params.from_file( params_file=os.path.join(args.serialization_dir, "config.json")) ### Instantiate tasks ### task_list = [] task_keys = [key for key in params.keys() if re.search("^task_", key)] for key in task_keys: logger.info("Creating {}", key) task_params = params.pop(key) task_description = task_params.pop("task_description") task_data_params = task_params.pop("data_params") task = Task.from_params(params=task_description) task_list.append(task) _, _ = task.load_data_from_params(params=task_data_params) ### Load Vocabulary from files ### vocab = Vocabulary.from_files( os.path.join(args.serialization_dir, "vocabulary")) logger.info("Vocabulary loaded") ### Load the data iterators ###
def _train_epoch(self, total_n_tr_batches: int, sampling_prob: List) -> Dict[str, float]:
    self._model.train()  # Set the model to "train" mode.

    ### Reset the training and trained-batches counters before a new training epoch ###
    for _, task_info in self._task_infos.items():
        task_info["tr_loss_cum"] = 0.0
        task_info["stm_loss"] = 0.0
        task_info["s_d_loss"] = 0.0
        task_info["valid_loss"] = 0.0
        task_info["n_batches_trained_this_epoch"] = 0
    all_tr_metrics = {}  # TODO: complete this comment to make it clearer

    ### Start training epoch ###
    epoch_tqdm = tqdm.tqdm(range(total_n_tr_batches), total=total_n_tr_batches)
    histogram_parameters = set(
        self._model.get_parameters_for_histogram_tensorboard_logging())
    for step, _ in enumerate(epoch_tqdm):
        # Sample a task according to the sampling probability distribution
        task_idx = np.argmax(np.random.multinomial(1, sampling_prob))
        task = self._task_list[task_idx]
        task_info = self._task_infos[task._name]

        ### One forward + backward pass per training stage ###
        # Fetch the next batch to train on
        batch = next(self._tr_generators[task._name])
        self._batch_num_total += 1
        task_info["n_batches_trained_this_epoch"] += 1

        # -------------------------------------
        # train sentiment classifier
        # -------------------------------------
        optimizer = self._optimizers["all_params"]

        # Get the loss for this batch
        output_dict = self._forward(tensor_batch=batch, task=task, for_training=True,
                                    train_stage="stm")
        loss = output_dict["loss"]
        task_info["stm_loss"] += loss.item()
        loss.backward()
        del loss

        self._log_params_update(optimizer)

        # -------------------------------------
        # train domain classifier (disabled, kept for reference)
        # -------------------------------------
        # optimizer = self._optimizers["share_discriminator"]
        #
        # # Get the loss for this batch
        # output_dict = self._forward(tensor_batch=batch, task=task, for_training=True,
        #                             train_stage="sd")
        # loss = output_dict["loss"]
        # task_info["s_d_loss"] += loss.item()
        # loss.backward()
        # del loss
        #
        # self._log_params_update(optimizer)

        # -------------------------------------
        # train adversarial domain classifier (reversed gradients)
        # -------------------------------------
        optimizer = self._optimizers["all_params"]

        # Get the loss for this batch, down-weighted by 0.05
        output_dict = self._forward(tensor_batch=batch, task=task, for_training=True,
                                    reverse=True, train_stage="sd")
        loss = 0.05 * output_dict["loss"]
        task_info["s_d_loss"] += loss.item()
        loss.backward()
        del loss

        self._log_params_update(optimizer)

        # -------------------------------------
        # train valid classifier (disabled, kept for reference)
        # -------------------------------------
        # optimizer = self._optimizers["valid_discriminator"]
        #
        # # Get the loss for this batch
        # output_dict = self._forward(tensor_batch=batch, task=task, for_training=True,
        #                             train_stage="valid")
        # loss = output_dict["loss"]
        # task_info["valid_loss"] += loss.item()
        # loss.backward()
        # del loss
        #
        # self._log_params_update(optimizer)

        # -------------------------------------
        # train adversarial valid classifier (reversed gradients)
        # -------------------------------------
        optimizer = self._optimizers["all_params"]

        # Get the loss for this batch, down-weighted by 0.05
        output_dict = self._forward(tensor_batch=batch, task=task, for_training=True,
                                    reverse=True, train_stage="valid")
        loss = 0.05 * output_dict["loss"]
        task_info["valid_loss"] += loss.item()
        loss.backward()
        del loss

        self._log_params_update(optimizer)

        ### Get metrics for all progress so far, update tqdm, display description ###
        task_metrics = self._model.get_metrics(task_name=task._name)
        n_trained = task_info["n_batches_trained_this_epoch"] + 1e-6  # avoid division by zero
        task_metrics["stm_loss"] = float(task_info["stm_loss"] / n_trained)
        task_metrics["s_d_loss"] = float(task_info["s_d_loss"] / n_trained)
(task_info["n_batches_trained_this_epoch"] + 0.000_001)) task_metrics["valid_loss"] = float( task_info["valid_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001)) description = training_util.description_from_metrics(task_metrics) epoch_tqdm.set_description(description) self._log_params_values(optimizer, task._name, task_metrics, histogram_parameters) self._global_step += 1 ### Bookkeeping all the training metrics for all the tasks on the training epoch that just finished ### for task in self._task_list: task_info = self._task_infos[task._name] task_info["total_n_batches_trained"] += task_info[ "n_batches_trained_this_epoch"] task_info["last_log"] = time.time() task_metrics = self._model.get_metrics(task_name=task._name, reset=True) if task._name not in all_tr_metrics: all_tr_metrics[task._name] = {} for name, value in task_metrics.items(): all_tr_metrics[task._name][name] = value all_tr_metrics[task._name]["stm_loss"] = float( task_info["stm_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01)) all_tr_metrics[task._name]["s_d_loss"] = float( task_info["s_d_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01)) all_tr_metrics[task._name]["valid_loss"] = float( task_info["valid_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01)) # Tensorboard - Training metrics for this epoch for metric_name, value in all_tr_metrics[task._name].items(): self._tensorboard.add_train_scalar(name="task_" + task._name + "/" + metric_name, value=value) logger.info("Train - End") return all_tr_metrics
def forward(self,  # type: ignore
            task_index: torch.IntTensor,
            reverse: torch.ByteTensor,
            epoch_trained: torch.IntTensor,
            for_training: torch.ByteTensor,
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    embeddeds = self._encoder(task_index, tokens, epoch_trained,
                              self._valid_discriminator, reverse, for_training)
    batch_size = get_batch_size(embeddeds["embedded_text"])

    sentiment_logits = self._sentiment_discriminator(embeddeds["embedded_text"])

    p_domain_logits = self._p_domain_discriminator(embeddeds["private_embedding"])

    # TODO set reverse = true
    s_domain_logits = self._s_domain_discriminator(embeddeds["share_embedding"],
                                                   reverse=reverse)

    # TODO set reverse = true
    # TODO use share_embedding instead of domain_embedding
    valid_logits = self._valid_discriminator(embeddeds["domain_embedding"],
                                             reverse=reverse)

    valid_label = embeddeds["valid"]

    logits = [sentiment_logits, p_domain_logits, s_domain_logits, valid_logits]

    # domain_logits = self._domain_discriminator(embedded_text)
    output_dict = {"logits": sentiment_logits}
    if label is not None:
        loss = self._loss(sentiment_logits, label)
        # task_index = task_index.unsqueeze(0)
        task_index = task_index.expand(batch_size)
        targets = [label, task_index, task_index, valid_label]
        p_domain_loss = self._domain_loss(p_domain_logits, task_index)
        s_domain_loss = self._domain_loss(s_domain_logits, task_index)
        logger.info("Share domain logits standard deviation is {}",
                    torch.mean(torch.std(F.softmax(s_domain_logits, dim=-1), dim=-1)))
        if self._label_smoothing is not None and self._label_smoothing > 0.0:
            valid_loss = sequence_cross_entropy_with_logits(
                valid_logits,
                valid_label.unsqueeze(0).cuda(),
                torch.tensor(1).unsqueeze(0).cuda(),
                average="token",
                label_smoothing=self._label_smoothing)
        else:
            # Build a one-hot target for the valid/fake discriminator (the index
            # passed to scatter_ must be 1-D, hence the unsqueeze).
            valid_loss = self._valid_loss(
                valid_logits,
                torch.zeros(2).scatter_(0, valid_label.unsqueeze(0), 1.0).cuda())
        output_dict["stm_loss"] = loss
        output_dict["p_d_loss"] = p_domain_loss
        output_dict["s_d_loss"] = s_domain_loss
        output_dict["valid_loss"] = valid_loss
        # TODO add share domain logits std loss
        output_dict["loss"] = loss + p_domain_loss + 0.005 * s_domain_loss
        # + 0.005 * valid_loss
        # + torch.mean(torch.std(s_domain_logits, dim=1))
        # output_dict["loss"] = loss + p_domain_loss + 0.005 * s_domain_loss

        for metric, logit, target in zip(self.metrics.values(), logits, targets):
            metric(logit, target)

    return output_dict
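# Sketch of the one-hot target built for the valid/fake discriminator above:
# `valid_label` is the 0/1 flag produced by the encoder, and scatter_ turns it
# into a one-hot vector that _valid_loss (presumably BCE-style) consumes.

import torch

demo_label = torch.tensor(1)
demo_one_hot = torch.zeros(2).scatter_(0, demo_label.unsqueeze(0), 1.0)
print(demo_one_hot)  # tensor([0., 1.])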