Example #1
def evaluate(model: Model, instances: Iterable[Instance], task_name: str,
             data_iterator: DataIterator, cuda_device: int) -> Dict[str, Any]:
    """
    Evaluate a model on a particular task (usually after training).
    
    Parameters
    ----------
    model : ``allennlp.models.model.Model``, required
        The model to evaluate
    instances : ``Iterable[Instance]``, required
        The (usually test) dataset on which to evaluate the model.
    task_name : ``str``, required
        The name of the task on which to evaluate the model.
    data_iterator : ``DataIterator``
        Iterator that goes through the dataset.
    cuda_device : ``int``
        CUDA device to use.
        
    Returns
    -------
    metrics :  ``Dict[str, Any]``
        A dictionary containing the metrics on the evaluated dataset.
    """
    check_for_gpu(cuda_device)
    with torch.no_grad():
        model.eval()

        iterator = data_iterator(instances, num_epochs=1, shuffle=False)
        logger.info("Iterating over dataset")
        generator_tqdm = tqdm.tqdm(
            iterator, total=data_iterator.get_num_batches(instances))

        eval_loss = 0
        nb_batches = 0
        for tensor_batch in generator_tqdm:
            nb_batches += 1

            train_stages = ["stm", "sd", "valid"]
            task_index = TASKS_NAME.index(task_name)
            tensor_batch['task_index'] = torch.tensor(task_index)
            tensor_batch["reverse"] = torch.tensor(False)
            tensor_batch['for_training'] = torch.tensor(False)
            train_stage = train_stages.index("stm")
            tensor_batch['train_stage'] = torch.tensor(train_stage)
            tensor_batch = move_to_device(tensor_batch, cuda_device)

            eval_output_dict = model.forward(**tensor_batch)
            loss = eval_output_dict["loss"]
            eval_loss += loss.item()
            metrics = model.get_metrics(task_name=task_name)
            metrics["stm_loss"] = float(eval_loss / nb_batches)

            description = training_util.description_from_metrics(metrics)
            generator_tqdm.set_description(description, refresh=False)

        metrics = model.get_metrics(task_name=task_name, reset=True)
        metrics["stm_loss"] = float(eval_loss / nb_batches)
        return metrics
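A minimal usage sketch for evaluate (hypothetical: it assumes an AllenNLP 0.x setup where a trained model, a fitted vocab and the test_instances list already exist, and "books" is a placeholder task name from TASKS_NAME):

from allennlp.data.iterators import BucketIterator

iterator = BucketIterator(batch_size=32)
iterator.index_with(vocab)  # the iterator must know the vocabulary to index batches
metrics = evaluate(model=model, instances=test_instances, task_name="books",
                   data_iterator=iterator, cuda_device=0)
print(metrics)  # e.g. {"sentiment_acc": ..., "stm_loss": ...}
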
Example #2
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 share_encoder: Seq2VecEncoder = None,
                 private_encoder: Seq2VecEncoder = None,
                 dropout: float = None,
                 input_dropout: float = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None) -> None:
        super(JointSentimentClassifier, self).__init__(vocab=vocab, regularizer=regularizer)

        self._text_field_embedder = text_field_embedder
        self._domain_embeddings = Embedding(len(TASKS_NAME), 50)
        if share_encoder is None and private_encoder is None:
            share_rnn = nn.LSTM(input_size=self._text_field_embedder.get_output_dim(),
                                hidden_size=150,
                                batch_first=True,
                                dropout=dropout if dropout is not None else 0.0,
                                bidirectional=True)
            share_encoder = PytorchSeq2SeqWrapper(share_rnn)
            private_rnn = nn.LSTM(input_size=self._text_field_embedder.get_output_dim(),
                                  hidden_size=150,
                                  batch_first=True,
                                  dropout=dropout if dropout is not None else 0.0,
                                  bidirectional=True)
            private_encoder = PytorchSeq2SeqWrapper(private_rnn)
            logger.info("Using LSTM as encoder")
            self._domain_embeddings = Embedding(len(TASKS_NAME), self._text_field_embedder.get_output_dim())
        self._share_encoder = share_encoder

        self._s_domain_discriminator = Discriminator(share_encoder.get_output_dim(), len(TASKS_NAME))

        self._p_domain_discriminator = Discriminator(private_encoder.get_output_dim(), len(TASKS_NAME))

        # TODO individual valid discriminator
        self._valid_discriminator = Discriminator(self._domain_embeddings.get_output_dim(), 2)

        for task in TASKS_NAME:
            tagger = SentimentClassifier(
                vocab=vocab,
                text_field_embedder=self._text_field_embedder,
                share_encoder=self._share_encoder,
                private_encoder=copy.deepcopy(private_encoder),
                domain_embeddings=self._domain_embeddings,
                s_domain_discriminator=self._s_domain_discriminator,
                p_domain_discriminator=self._p_domain_discriminator,
                valid_discriminator=self._valid_discriminator,
                dropout=dropout,
                input_dropout=input_dropout,
                label_smoothing=0.1,
                initializer=initializer
            )
            self.add_module("_tagger_{}".format(task), tagger)

        logger.info("Multi-Task Learning Model has been instantiated.")
Example #3
    def forward(self, task_index: torch.IntTensor,
                tokens: Dict[str,
                             torch.LongTensor], epoch_trained: torch.IntTensor,
                valid_discriminator: Discriminator, reverse: torch.ByteTensor,
                for_training: torch.ByteTensor) -> Dict[str, torch.Tensor]:
        embedded_text_input = self._text_field_embedder(tokens)
        tokens_mask = util.get_text_field_mask(tokens)
        batch_size = get_batch_size(tokens)
        # TODO
        if np.random.rand() < -1 and for_training.all():
            logger.info("Domain Embedding with Perturbation")
            domain_embeddings = self._domain_embeddings(
                torch.arange(0, len(TASKS_NAME)).cuda())
            domain_embedding = get_perturbation_domain_embedding(
                domain_embeddings, task_index, epoch_trained)
            # domain_embedding = FGSM(self._domain_embeddings, task_index, valid_discriminator)
            output_dict = {"valid": torch.tensor(0)}
        else:
            logger.info("Domain Embedding without Perturbation")
            domain_embedding = self._domain_embeddings(task_index)
            output_dict = {"valid": torch.tensor(1)}

        output_dict["domain_embedding"] = domain_embedding
        embedded_text_input = self._input_dropout(embedded_text_input)
        if self._with_domain_embedding:
            domain_embedding = domain_embedding.expand(batch_size, 1, -1)
            embedded_text_input = torch.cat(
                (domain_embedding, embedded_text_input), 1)
            tokens_mask = torch.cat(
                [tokens_mask.new_ones(batch_size, 1), tokens_mask], 1)

        shared_encoded_text = self._shared_encoder(embedded_text_input,
                                                   tokens_mask)
        # shared_encoded_text = self._seq2vec(shared_encoded_text, tokens_mask)
        shared_encoded_text = get_final_encoder_states(shared_encoded_text,
                                                       tokens_mask,
                                                       bidirectional=True)
        output_dict["share_embedding"] = shared_encoded_text

        private_encoded_text = self._private_encoder(embedded_text_input,
                                                     tokens_mask)
        # private_encoded_text = self._seq2vec(private_encoded_text)
        private_encoded_text = get_final_encoder_states(private_encoded_text,
                                                        tokens_mask,
                                                        bidirectional=True)
        output_dict["private_embedding"] = private_encoded_text

        embedded_text = torch.cat([shared_encoded_text, private_encoded_text],
                                  -1)
        output_dict["embedded_text"] = embedded_text
        return output_dict
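A tensor-level sketch of the domain-token trick in the branch above: the domain embedding is prepended as one extra "token" and the mask gains a leading position (all shapes here are made up for illustration):

import torch

batch_size, seq_len, dim = 2, 5, 8
embedded_text_input = torch.randn(batch_size, seq_len, dim)
tokens_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
domain_embedding = torch.randn(dim)  # one embedding vector per task/domain

domain_embedding = domain_embedding.expand(batch_size, 1, -1)                    # (2, 1, 8)
embedded_text_input = torch.cat((domain_embedding, embedded_text_input), 1)      # (2, 6, 8)
tokens_mask = torch.cat([tokens_mask.new_ones(batch_size, 1), tokens_mask], 1)   # (2, 6)
print(embedded_text_input.shape, tokens_mask.shape)
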
Example #4
    def _save_checkpoint(self, epoch: int, should_stop: bool, is_best: bool = False) -> None:
        ### Saving training state ###
        training_state = {
            "epoch": epoch,
            "should_stop": should_stop,
            "metric_infos": self._metric_infos,
            "task_infos": self._task_infos,
            "schedulers": {},
            "optimizers": {},
        }

        if self._optimizers is not None:
            for task_name, optimizers in self._optimizers.items():
                training_state["optimizers"][task_name] = {}
                for params_name, optimizer in optimizers.items():
                    training_state["optimizers"][task_name][params_name] = optimizer.state_dict()
        if self._schedulers is not None:
            for task_name, scheduler in self._schedulers.items():
                training_state["schedulers"][task_name] = scheduler.lr_scheduler.state_dict()

        training_path = os.path.join(self._serialization_dir, "training_state.th")
        torch.save(training_state, training_path)
        logger.info("Checkpoint - Saved training state to {}", training_path)

        ### Saving model state ###
        model_path = os.path.join(self._serialization_dir, "model_state.th")
        model_state = self._model.state_dict()
        torch.save(model_state, model_path)
        logger.info("Checkpoint - Saved model state to {}", model_path)

        if is_best:
            logger.info("Checkpoint - Best validation performance so far for all tasks")
            logger.info("Checkpoint - Copying weights to '{}/best_all.th'.", self._serialization_dir)
            shutil.copyfile(model_path, os.path.join(self._serialization_dir, "best_all.th"))

        ### Saving best models for each task ###
        for task_name, infos in self._metric_infos.items():
            best_epoch, _ = infos["best"]
            if best_epoch == epoch:
                logger.info("Checkpoint - Best validation performance so far for {} tasks", task_name)
                logger.info("Checkpoint - Copying weights to '{}/best_{}.th'.", self._serialization_dir, task_name)
                shutil.copyfile(model_path, os.path.join(self._serialization_dir, "best_{}.th".format(task_name)))
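A hedged sketch of reading this checkpoint back (the project's actual _restore_checkpoint is not shown in these snippets; the directory name and the model object are assumptions):

import os
import torch

serialization_dir = "serialization_dir"  # placeholder path
training_state = torch.load(os.path.join(serialization_dir, "training_state.th"))
model_state = torch.load(os.path.join(serialization_dir, "model_state.th"))
# model.load_state_dict(model_state)  # `model` assumed to be an already-constructed Model
n_epoch = training_state["epoch"]
should_stop = training_state["should_stop"]
# Per-task optimizer/scheduler states live under
# training_state["optimizers"][task_name][params_name] and training_state["schedulers"][task_name].
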
Example #5
    def train(self, recover: bool = False) -> Dict[str, Any]:

        # 1 train sentiment classifier & private classifier & domain embeddings => init G 50 epoch
        # 2 fix share encoder(+domain embeddings?), train share classifier(cls&real/fake) & others => train D
        # 3 fix share classifier, train share encoder, reverse share classifier input gradient  min loss => train G
        training_start_time = time.time()

        if recover:
            try:
                n_epoch, should_stop = self._restore_checkpoint()
                logger.info("Loaded model from checkpoint. Starting at epoch {}", n_epoch)
            except RuntimeError:
                raise ConfigurationError(
                    "Could not recover training from the checkpoint.  Did you mean to output to "
                    "a different serialization directory or delete the existing serialization "
                    "directory?"
                )
        else:
            n_epoch, should_stop = 0, False

            ### Store all the necessary information and attributes about the tasks ###
            task_infos = {task._name: {} for task in self._task_list}
            for task_idx, task in enumerate(self._task_list):
                task_info = task_infos[task._name]

                # Store statistics on training and validation batches
                data_iterator = task._data_iterator
                n_tr_batches = data_iterator.get_num_batches(task._train_data)
                n_val_batches = data_iterator.get_num_batches(task._validation_data)
                task_info["n_tr_batches"] = n_tr_batches
                task_info["n_val_batches"] = n_val_batches

                # Create a counter for the number of batches trained during the whole
                # training for this specific task
                task_info["total_n_batches_trained"] = 0

                task_info["last_log"] = time.time()  # Time of last logging
            self._task_infos = task_infos

            ### Bookkeeping the validation metrics ###
            metric_infos = {
                task._name: {
                    "val_metric": task._val_metric,
                    "hist": [],
                    "is_out_of_patience": False,
                    "min_lr_hit": False,
                    "best": (-1, {}),
                }
                for task in self._task_list
            }
            self._metric_infos = metric_infos

        ### Write log ###
        total_n_tr_batches = 0  # The total number of training batches across all the datasets.
        for task_name, info in self._task_infos.items():
            total_n_tr_batches += info["n_tr_batches"]
            logger.info("Task {}:", task_name)
            logger.info("\t{} training batches", info["n_tr_batches"])
            logger.info("\t{} validation batches", info["n_val_batches"])

        ### Create the training generators/iterators tqdm ###
        self._tr_generators = {}
        for task in self._task_list:
            data_iterator = task._data_iterator
            tr_generator = data_iterator(task._train_data, num_epochs=None)
            self._tr_generators[task._name] = tr_generator

        ### Create sampling probability distribution ###
        if self._sampling_method == "uniform":
            sampling_prob = [float(1 / self._n_tasks)] * self._n_tasks
        elif self._sampling_method == "proportional":
            sampling_prob = [float(info["n_tr_batches"] / total_n_tr_batches) for info in self._task_infos.values()]

        ### Enable gradient clipping ###
        # Only if self._grad_clipping is specified
        self._enable_gradient_clipping()

        ### Setup is ready. Training of the model can begin ###
        logger.info("Set up ready. Beginning training/validation.")

        avg_accuracies = []
        best_accuracy = 0.0

        ### Begin Training of the model ###
        while not should_stop:
            ### Log Infos: current epoch count and CPU/GPU usage ###
            logger.info("")
            logger.info("Epoch {}/{} - Begin", n_epoch, self._num_epochs - 1)
            logger.info(f"Peak CPU memory usage MB: {peak_memory_mb()}")
            for gpu, memory in gpu_memory_mb().items():
                logger.info(f"GPU {gpu} memory usage MB: {memory}")

            # if n_epoch <= 10:
            #     # init generator
            #     all_tr_metrics = self._train_epoch(total_n_tr_batches, sampling_prob)
            # # train discriminator 3 epochs
            # # elif 10 < n_epoch < 20 or n_epoch % 2 == 0:
            # #     all_tr_metrics = self._train_epoch(total_n_tr_batches, sampling_prob, train_D=True)
            # else:
            # train adversarial generator every 3 epoch
            all_tr_metrics = self._train_epoch(total_n_tr_batches, sampling_prob, reverse=True)

            all_val_metrics, avg_accuracy = self._validation(n_epoch)
            is_best = False
            if best_accuracy < avg_accuracy:
                best_accuracy = avg_accuracy
                logger.info("Best accuracy found --- {}", best_accuracy / self._n_tasks)
                is_best = True

            ### Print all training and validation metrics for this epoch ###
            logger.info("***** Epoch {}/{} Statistics *****", n_epoch, self._num_epochs - 1)
            for task in self._task_list:
                logger.info("Statistic: {}", task._name)
                logger.info(
                    "\tTraining - {}: {:3d}",
                    "Nb batches trained",
                    self._task_infos[task._name]["n_batches_trained_this_epoch"],
                )
                for metric_name, value in all_tr_metrics[task._name].items():
                    logger.info("\tTraining - {}: {:.3f}", metric_name, value)
                for metric_name, value in all_val_metrics[task._name].items():
                    logger.info("\tValidation - {}: {:.3f}", metric_name, value)
            logger.info("***** Average accuracy is {:.6f} *****", avg_accuracy / self._n_tasks)
            avg_accuracies.append(avg_accuracy / self._n_tasks)
            logger.info("**********")

            ### Check to see if should stop ###
            stop_tr, stop_val = True, True

            for task in self._task_list:
                # task_info = self._task_infos[task._name]
                if self._optimizers[task._name]['exclude_share_encoder'].param_groups[0]["lr"] < self._min_lr and \
                        self._optimizers[task._name]['exclude_share_discriminator'].param_groups[0][
                            "lr"] < self._min_lr:
                    logger.info("Minimum lr hit on {}.", task._name)
                    logger.info("Task {} vote to stop training.", task._name)
                    metric_infos[task._name]["min_lr_hit"] = True
                stop_tr = stop_tr and self._metric_infos[task._name]["min_lr_hit"]
                stop_val = stop_val and self._metric_infos[task._name]["is_out_of_patience"]

            if stop_tr:
                should_stop = True
                logger.info("All tasks hit minimum lr. Stopping training.")
            if stop_val:
                should_stop = True
                logger.info("All metrics ran out of patience. Stopping training.")
            if n_epoch >= self._num_epochs - 1:
                should_stop = True
                logger.info("Maximum number of epoch hit. Stopping training.")

            self._save_checkpoint(n_epoch, should_stop, is_best)

            ### Update n_epoch ###
            # One epoch = doing N (forward + backward) pass where N is the total number of training batches.
            n_epoch += 1
            self._epoch_trained = n_epoch

        logger.info("Max accuracy is {:.6f}", max(avg_accuracies))

        ### Summarize training at the end ###
        logger.info("***** Training is finished *****")
        logger.info("Stopped training after {} epochs", n_epoch)
        return_metrics = {}
        for task_name, task_info in self._task_infos.items():
            nb_epoch_trained = int(task_info["total_n_batches_trained"] / task_info["n_tr_batches"])
            logger.info(
                "Trained {} for {} batches ~= {} epochs",
                task_name,
                task_info["total_n_batches_trained"],
                nb_epoch_trained,
            )
            return_metrics[task_name] = {
                "best_epoch": self._metric_infos[task_name]["best"][0],
                "nb_epoch_trained": nb_epoch_trained,
                "best_epoch_val_metrics": self._metric_infos[task_name]["best"][1],
            }

        training_elapsed_time = time.time() - training_start_time
        return_metrics["training_duration"] = time.strftime("%d:%H:%M:%S", time.gmtime(training_elapsed_time))
        return_metrics["nb_epoch_trained"] = n_epoch

        return return_metrics
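A small sketch of the per-batch task sampling that train() sets up and _train_epoch() consumes (the batch counts are hypothetical):

import numpy as np

n_tr_batches = {"books": 200, "dvd": 100, "music": 100}  # hypothetical per-task batch counts
total_n_tr_batches = sum(n_tr_batches.values())

# "uniform" would weight every task equally; "proportional" weights by dataset size.
sampling_prob = [n / total_n_tr_batches for n in n_tr_batches.values()]
task_idx = int(np.argmax(np.random.multinomial(1, sampling_prob)))
print("next batch drawn from:", list(n_tr_batches)[task_idx])
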
Example #6
    def _validation(self, n_epoch: int) -> Tuple[Dict[str, Any], float]:
        ### Begin validation of the model ###
        logger.info("Validation - Begin")
        all_val_metrics = {}

        self._model.eval()  # Set the model into evaluation mode

        avg_accuracy = 0.0

        for task_idx, task in enumerate(self._task_list):
            logger.info("Validation - Task {}/{}: {}", task_idx + 1, self._n_tasks, task._name)

            val_loss = 0.0
            n_batches_val_this_epoch_this_task = 0
            n_val_batches = self._task_infos[task._name]["n_val_batches"]
            scheduler = self._schedulers[task._name]

            # Create tqdm generator for the current task's validation
            data_iterator = task._data_iterator
            val_generator = data_iterator(task._validation_data, num_epochs=1, shuffle=False)
            val_generator_tqdm = tqdm.tqdm(val_generator, total=n_val_batches)

            # Iterate over each validation batch for this task
            for batch in val_generator_tqdm:
                n_batches_val_this_epoch_this_task += 1

                # Get the loss
                val_output_dict = self._forward(batch, task=task, for_training=False)
                loss = val_output_dict["stm_loss"]
                val_loss += loss.item()
                del loss

                # Get metrics for all progress so far, update tqdm, display description
                task_metrics = self._get_metrics(task=task)
                task_metrics["loss"] = float(val_loss / n_batches_val_this_epoch_this_task)
                description = training_util.description_from_metrics(task_metrics)
                val_generator_tqdm.set_description(description)

            # Get task validation metrics and store them in all_val_metrics
            task_metrics = self._get_metrics(task=task, reset=True)
            if task._name not in all_val_metrics:
                all_val_metrics[task._name] = {}
            for name, value in task_metrics.items():
                all_val_metrics[task._name][name] = value
            all_val_metrics[task._name]["loss"] = float(val_loss / n_batches_val_this_epoch_this_task)

            avg_accuracy += task_metrics["sentiment_acc"]

            # Tensorboard - Validation metrics for this epoch
            for metric_name, value in all_val_metrics[task._name].items():
                self._tensorboard.add_validation_scalar(
                    name="task_" + task._name + "/" + metric_name, value=value
                )

            ### Perform a patience check and update the history of the validation metric for this task ###
            this_epoch_val_metric = all_val_metrics[task._name][task._val_metric]
            metric_history = self._metric_infos[task._name]["hist"]

            metric_history.append(this_epoch_val_metric)
            is_best_so_far, out_of_patience = self._check_history(
                metric_history=metric_history,
                cur_score=this_epoch_val_metric,
                should_decrease=task._val_metric_decreases,
            )

            if is_best_so_far:
                logger.info("Best model found for {}.", task._name)
                self._metric_infos[task._name]["best"] = (n_epoch, all_val_metrics)
            if out_of_patience and not self._metric_infos[task._name]["is_out_of_patience"]:
                self._metric_infos[task._name]["is_out_of_patience"] = True
                logger.info("Task {} is out of patience and vote to stop the training.", task._name)

            # The LRScheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            scheduler.step(this_epoch_val_metric, n_epoch)

        logger.info("Validation - End")
        return all_val_metrics, avg_accuracy
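The scheduler stepped above is metric-driven; a minimal sketch with ReduceLROnPlateau (an assumption, since the project's actual scheduler class is not shown in these snippets):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2)

for val_accuracy in [0.70, 0.72, 0.72, 0.71, 0.71]:
    scheduler.step(val_accuracy)  # the lr is halved once the metric stops improving for `patience` epochs
    print(optimizer.param_groups[0]["lr"])
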
Example #7
    def _train_epoch(self, total_n_tr_batches: int, sampling_prob: List,
                     reverse=False, train_D=False) -> Dict[str, float]:
        self._model.train()  # Set the model to "train" mode.

        if reverse:
            logger.info("Training Generator - Begin")
        elif not train_D:
            logger.info("Training Init Generator - Begin")

        if train_D:
            logger.info("Training Discriminator - Begin")
        logger.info("reverse is {}, train_D is {}", reverse, train_D)

        ### Reset training and trained batches counter before new training epoch ###
        for _, task_info in self._task_infos.items():
            task_info["tr_loss_cum"] = 0.0
            task_info['stm_loss'] = 0.0
            task_info['p_d_loss'] = 0.0
            task_info['s_d_loss'] = 0.0
            task_info['valid_loss'] = 0.0
            task_info["n_batches_trained_this_epoch"] = 0
        all_tr_metrics = {}  # Per-task training metrics accumulated over this epoch

        ### Start training epoch ###
        epoch_tqdm = tqdm.tqdm(range(total_n_tr_batches), total=total_n_tr_batches)
        histogram_parameters = set(self._model.get_parameters_for_histogram_tensorboard_logging())

        for step, _ in enumerate(epoch_tqdm):
            task_idx = np.argmax(np.random.multinomial(1, sampling_prob))
            task = self._task_list[task_idx]
            task_info = self._task_infos[task._name]

            ### One forward + backward pass ###
            # Call next batch to train
            batch = next(self._tr_generators[task._name])
            self._batch_num_total += 1
            task_info["n_batches_trained_this_epoch"] += 1

            # Load optimizer
            if not train_D:
                optimizer = self._optimizers[task._name]["all_params"]
            else:
                optimizer = self._optimizers[task._name]["exclude_share_encoder"]

            # Get the loss for this batch
            output_dict = self._forward(tensor_batch=batch, task=task, for_training=True, reverse=reverse)
            # if reverse or train_D:
            #     output_dict_fake = self._forward(tensor_batch=batch, task=task, for_training=True, reverse=True)
            # loss = output_dict["stm_loss"]
            # if train_D:
            #     loss = (output_dict["stm_loss"] + output_dict["s_d_loss"] + output_dict_fake["stm_loss"] +
            #             output_dict_fake["s_d_loss"]) / 2.0
            # if reverse:
            #     # loss = (output_dict["stm_loss"] + output_dict["p_d_loss"] + 0.005 * output_dict["s_d_loss"] +
            #     #         output_dict_fake["stm_loss"] + output_dict_fake["p_d_loss"] + 0.005 * output_dict_fake[
            #     #             "s_d_loss"]) / 2.0
            #     loss = (output_dict['loss'] + output_dict_fake['loss']) / 2.0
            loss = output_dict['loss']
            if self._gradient_accumulation_steps > 1:
                loss /= self._gradient_accumulation_steps
            loss.backward()
            task_info["tr_loss_cum"] += loss.item()
            task_info['stm_loss'] += output_dict['stm_loss'].item()
            task_info['p_d_loss'] += output_dict['p_d_loss'].item()
            task_info['s_d_loss'] += output_dict['s_d_loss'].item()
            task_info['valid_loss'] += output_dict['valid_loss'].item()
            # if reverse or train_D:
            #     task_info['stm_loss'] += output_dict_fake['stm_loss'].item()
            #     task_info['stm_loss'] /= 2.0
            #     task_info['p_d_loss'] += output_dict_fake['p_d_loss'].item()
            #     task_info['p_d_loss'] /= 2.0
            #     task_info['s_d_loss'] += output_dict_fake['s_d_loss'].item()
            #     task_info['s_d_loss'] /= 2.0
            #     task_info['valid_loss'] += output_dict_fake['valid_loss'].item()
            #     task_info['valid_loss'] /= 2.0
            del loss

            if (step + 1) % self._gradient_accumulation_steps == 0:
                batch_grad_norm = self._rescale_gradients()
                if self._tensorboard.should_log_histograms_this_batch():
                    param_updates = {name: param.detach().cpu().clone()
                                     for name, param in self._model.named_parameters()}
                    optimizer.step()
                    for name, param in self._model.named_parameters():
                        param_updates[name].sub_(param.detach().cpu())
                        update_norm = torch.norm(param_updates[name].view(-1, ))
                        param_norm = torch.norm(param.view(-1, )).cpu()
                        self._tensorboard.add_train_scalar("gradient_update/" + name,
                                                           update_norm / (param_norm + 1e-7))
                else:
                    optimizer.step()
                optimizer.zero_grad()

            ### Get metrics for all progress so far, update tqdm, display description ###
            task_metrics = self._get_metrics(task=task)
            task_metrics["loss"] = float(
                task_info["tr_loss_cum"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001)
            )
            task_metrics["stm_loss"] = float(
                task_info["stm_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001)
            )
            task_metrics["p_d_loss"] = float(
                task_info["p_d_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001)
            )
            task_metrics["s_d_loss"] = float(
                task_info["s_d_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001)
            )
            task_metrics["valid_loss"] = float(
                task_info["valid_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_001)
            )
            description = training_util.description_from_metrics(task_metrics)
            epoch_tqdm.set_description(task._name + ", " + description)

            # Log parameter values to Tensorboard
            if self._tensorboard.should_log_this_batch():
                self._tensorboard.log_parameter_and_gradient_statistics(self._model, batch_grad_norm)
                self._tensorboard.log_learning_rates(self._model, optimizer)

                self._tensorboard.log_metrics(
                    {"epoch_metrics/" + task._name + "/" + k: v for k, v in task_metrics.items()})

            if self._tensorboard.should_log_histograms_this_batch():
                self._tensorboard.log_histograms(self._model, histogram_parameters)
            self._global_step += 1

        ### Bookkeeping all the training metrics for all the tasks on the training epoch that just finished ###
        for task in self._task_list:
            task_info = self._task_infos[task._name]

            task_info["total_n_batches_trained"] += task_info["n_batches_trained_this_epoch"]
            task_info["last_log"] = time.time()

            task_metrics = self._get_metrics(task=task, reset=True)
            if task._name not in all_tr_metrics:
                all_tr_metrics[task._name] = {}
            for name, value in task_metrics.items():
                all_tr_metrics[task._name][name] = value
            all_tr_metrics[task._name]["loss"] = float(
                task_info["tr_loss_cum"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01)
            )
            all_tr_metrics[task._name]["stm_loss"] = float(
                task_info["stm_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01)
            )
            all_tr_metrics[task._name]["p_d_loss"] = float(
                task_info["p_d_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01)
            )
            all_tr_metrics[task._name]["s_d_loss"] = float(
                task_info["s_d_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01)
            )
            all_tr_metrics[task._name]["valid_loss"] = float(
                task_info["valid_loss"] / (task_info["n_batches_trained_this_epoch"] + 0.000_000_01)
            )

            # Tensorboard - Training metrics for this epoch
            for metric_name, value in all_tr_metrics[task._name].items():
                self._tensorboard.add_train_scalar(
                    name="task_" + task._name + "/" + metric_name, value=value
                )

        logger.info("Train - End")
        return all_tr_metrics
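A self-contained sketch of the gradient-accumulation pattern used in the loop above: the loss is scaled down and optimizer.step() runs only every N micro-batches (the model and data are placeholders):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4

for step in range(8):
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss = loss / gradient_accumulation_steps  # so the accumulated gradient matches one large batch
    loss.backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
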
Example #8
def tasks_and_vocab_from_params(
        params: Params,
        serialization_dir: str) -> Tuple[List[Task], Vocabulary]:
    """
    Load each of the tasks in the model from the ``params`` file
    and load the datasets associated with each of these tasks.
    Create the vocabulary from ``params`` using the concatenation of the ``datasets_for_vocab_creation``
    from each task's specific datasets.

    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an experiment.
    serialization_dir: ``str``
        Directory in which to save the model and its logs.
    Returns
    -------
    task_list: ``List[Task]``
        A list containing the tasks of the model to train.
    vocab: ``Vocabulary``
        The vocabulary fitted on the datasets_for_vocab_creation.
    """
    ### Instantiate the different tasks ###
    task_list = []
    instances_for_vocab_creation = itertools.chain()
    datasets_for_vocab_creation = {}
    task_keys = TASKS_NAME
    data_dir = "data/mtl-dataset"

    for key in task_keys:
        logger.info("Creating {}", key)
        task_data_params = Params({
            "dataset_reader": {
                "type": "semantic_review"
            },
            "train_data_path":
            os.path.join(data_dir, key + ".task.train"),
            "test_data_path":
            os.path.join(data_dir, key + ".task.test"),
            "validation_data_path":
            os.path.join(data_dir, key + ".task.test"),
            "datasets_for_vocab_creation": ["train", "validation", "test"]
        })

        task_description = Params({
            "task_name":
            key,
            "validation_metric_name":
            "{}_stm_acc".format(key)
        })

        task = Task.from_params(params=task_description)
        task_list.append(task)

        task_instances_for_vocab, task_datasets_for_vocab = task.load_data_from_params(
            params=task_data_params)
        instances_for_vocab_creation = itertools.chain(
            instances_for_vocab_creation, task_instances_for_vocab)
        datasets_for_vocab_creation[task._name] = task_datasets_for_vocab

    ### Create and save the vocabulary ###
    for task_name, task_dataset_list in datasets_for_vocab_creation.items():
        logger.info("Creating a vocabulary using {} data from {}.",
                    ", ".join(task_dataset_list), task_name)

    logger.info("Fitting vocabulary from dataset")
    vocab = Vocabulary.from_params(params.pop("vocabulary", {}),
                                   instances_for_vocab_creation)

    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
    logger.info("Vocabulary saved to {}",
                os.path.join(serialization_dir, "vocabulary"))

    return task_list, vocab
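A hypothetical call sketch for tasks_and_vocab_from_params (the config path and serialization directory are placeholders):

from allennlp.common.params import Params

params = Params.from_file("config.json")  # assumed experiment config
task_list, vocab = tasks_and_vocab_from_params(
    params=params, serialization_dir="serialization_dir")
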
Example #9
def train_model(multi_task_trainer: TransferMtlTrainer,
                recover: bool = False) -> Dict[str, Any]:
    """
    Launch the training of the multi-task model.

    Parameters
    ----------
    multi_task_trainer: ``TransferMtlTrainer``
        A trainer (similar to allennlp.training.trainer.Trainer) that can handle multi-task training.
    recover : ``bool``, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory.  This is only intended for use when something actually crashed during the middle
        of a run.  For continuing training a model on new data, see the ``fine-tune`` command.
                    
    Returns
    -------
    metrics: ``Dict[str, Any]``
        The different metrics summarizing the training of the model.
        It includes the validation and test (if necessary) metrics.
    """
    ### Train the multi-task model ###
    metrics = multi_task_trainer.train(recover=recover)

    task_list = multi_task_trainer._task_list
    serialization_dir = multi_task_trainer._serialization_dir
    model = multi_task_trainer._model

    # Evaluate the model on test data, if necessary. This is a multi-task learning framework: the best
    # validation metrics for one task are not necessarily obtained at the same epoch for all the tasks,
    # one epoch being equal to N forward+backward passes, where N is the total number of batches across
    # all the training sets. We evaluate each task's best model (selected on its validation metric) on
    # all the other tasks (which have a test set).
    avg_accuracies = []
    for task in task_list:
        if not task._evaluate_on_test:
            continue

        logger.info("Task {} will be evaluated using the best epoch weights.",
                    task._name)
        assert (
            task._test_data is not None
        ), "Task {} wants to be evaluated on a test dataset but there is no test data loaded.".format(
            task._name)

        logger.info("Loading the best epoch weights for tasks {}", task._name)
        best_model_state_path = os.path.join(serialization_dir,
                                             "best_{}.th".format(task._name))
        best_model_state = torch.load(best_model_state_path)
        best_model = model
        best_model.load_state_dict(state_dict=best_model_state)

        test_metric_dict = {}
        avg_accuracy = 0.0

        for pair_task in task_list:
            if not pair_task._evaluate_on_test:
                continue

            logger.info(
                "Pair task {} is evaluated with the best model for {}",
                pair_task._name, task._name)
            test_metric_dict[pair_task._name] = {}
            test_metrics = evaluate(
                model=best_model,
                task_name=pair_task._name,
                instances=pair_task._test_data,
                data_iterator=pair_task._data_iterator,
                cuda_device=multi_task_trainer._cuda_device,
            )

            for metric_name, value in test_metrics.items():
                test_metric_dict[pair_task._name][metric_name] = value
            avg_accuracy += test_metrics["{}_stm_acc".format(pair_task._name)]
        logger.info("Average accuracy of task {} is {}", task._name,
                    avg_accuracy / len(task_list))
        avg_accuracies.append(avg_accuracy / len(task_list))

        metrics[task._name]["test"] = deepcopy(test_metric_dict)
        logger.info("Finished evaluation of tasks {}.", task._name)

    ### Dump validation and possibly test metrics ###
    metrics_json = json.dumps(metrics, indent=2)
    with open(os.path.join(serialization_dir, "metrics.json"),
              "w") as metrics_file:
        metrics_file.write(metrics_json)
    logger.info("Metrics: {}", metrics_json)
    logger.info("Average accuracy is {}",
                sum(avg_accuracies) / len(avg_accuracies))

    return metrics
Example #10
    tasks, vocab = tasks_and_vocab_from_params(
        params=params, serialization_dir=serialization_dir)

    ### Load the data iterators for each task ###
    tasks = create_and_set_iterators(params=params,
                                     task_list=tasks,
                                     vocab=vocab)

    ### Load Regularizations ###
    regularizer = RegularizerApplicator.from_params(
        params.pop("regularizer", []))

    ### Create model ###
    model_params = params.pop("model")
    model = Model.from_params(vocab=vocab,
                              params=model_params,
                              regularizer=regularizer)
    ### Create the multi-task trainer ###
    multi_task_trainer_params = params.pop("multi_task_trainer")
    trainer = TransferMtlTrainer.from_params(
        model=model,
        task_list=tasks,
        serialization_dir=serialization_dir,
        params=multi_task_trainer_params)

    ### Launch training ###
    metrics = train_model(multi_task_trainer=trainer, recover=args.recover)
    if metrics is not None:
        logger.info(
            "Training is finished ! Let's have a drink. It's on the house !")
Example #11
        action="store_true",
        required=False,
        default=False,
        help="Whether or not evaluate using gold mentions in coreference",
    )
    args = parser.parse_args()

    params = Params.from_file(
        params_file=os.path.join(args.serialization_dir, "config.json"))

    ### Instantiate tasks ###
    task_list = []
    task_keys = [key for key in params.keys() if re.search("^task_", key)]

    for key in task_keys:
        logger.info("Creating {}", key)
        task_params = params.pop(key)
        task_description = task_params.pop("task_description")
        task_data_params = task_params.pop("data_params")

        task = Task.from_params(params=task_description)
        task_list.append(task)

        _, _ = task.load_data_from_params(params=task_data_params)

    ### Load Vocabulary from files ###
    vocab = Vocabulary.from_files(
        os.path.join(args.serialization_dir, "vocabulary"))
    logger.info("Vocabulary loaded")

    ### Load the data iterators ###
Example #12
    def _train_epoch(self, total_n_tr_batches: int,
                     sampling_prob: List) -> Dict[str, float]:
        self._model.train()  # Set the model to "train" mode.

        ### Reset training and trained batches counter before new training epoch ###
        for _, task_info in self._task_infos.items():
            task_info["tr_loss_cum"] = 0.0
            task_info['stm_loss'] = 0.0
            task_info['s_d_loss'] = 0.0
            task_info['valid_loss'] = 0.0
            task_info["n_batches_trained_this_epoch"] = 0
        all_tr_metrics = {}  # Per-task training metrics accumulated over this epoch

        ### Start training epoch ###
        epoch_tqdm = tqdm.tqdm(range(total_n_tr_batches),
                               total=total_n_tr_batches)
        histogram_parameters = set(
            self._model.get_parameters_for_histogram_tensorboard_logging())

        for step, _ in enumerate(epoch_tqdm):
            task_idx = np.argmax(np.random.multinomial(1, sampling_prob))
            task = self._task_list[task_idx]
            task_info = self._task_infos[task._name]

            ### One forward + backward pass ###
            # Call next batch to train
            batch = next(self._tr_generators[task._name])
            self._batch_num_total += 1
            task_info["n_batches_trained_this_epoch"] += 1

            # -------------------------------------
            # train the sentiment classifier
            # -------------------------------------

            # Load optimizer
            optimizer = self._optimizers["all_params"]

            # Get the loss for this batch
            output_dict = self._forward(tensor_batch=batch,
                                        task=task,
                                        for_training=True,
                                        train_stage="stm")

            loss = output_dict['loss']
            task_info['stm_loss'] += loss.item()
            loss.backward()
            del loss

            self._log_params_update(optimizer)

            # -------------------------------------
            # train the domain classifier
            # -------------------------------------

            # optimizer = self._optimizers["share_discriminator"]
            #
            # # Get the loss for this batch
            # output_dict = self._forward(tensor_batch=batch, task=task, for_training=True, train_stage="sd")
            #
            # loss = output_dict['loss']
            # task_info['s_d_loss'] += loss.item()
            # loss.backward()
            # del loss
            #
            # self._log_params_update(optimizer)

            # -------------------------------------
            # train the adversarial domain classifier
            # -------------------------------------
            optimizer = self._optimizers["all_params"]

            # Get the loss for this batch
            output_dict = self._forward(tensor_batch=batch,
                                        task=task,
                                        for_training=True,
                                        reverse=True,
                                        train_stage="sd")

            loss = output_dict['loss']
            loss = 0.05 * loss
            task_info['s_d_loss'] += loss.item()
            loss.backward()
            del loss

            self._log_params_update(optimizer)

            # -------------------------------------
            # train the valid classifier
            # -------------------------------------
            # optimizer = self._optimizers["valid_discriminator"]
            #
            # # Get the loss for this batch
            # output_dict = self._forward(tensor_batch=batch, task=task, for_training=True, train_stage="valid")
            #
            # loss = output_dict['loss']
            # task_info['valid_loss'] += loss.item()
            # loss.backward()
            # del loss
            #
            # self._log_params_update(optimizer)
            # -------------------------------------
            # train the adversarial valid classifier
            # -------------------------------------
            optimizer = self._optimizers["all_params"]

            # Get the loss for this batch
            output_dict = self._forward(tensor_batch=batch,
                                        task=task,
                                        for_training=True,
                                        reverse=True,
                                        train_stage="valid")

            loss = output_dict['loss']
            loss = 0.05 * loss
            task_info['valid_loss'] += loss.item()
            loss.backward()
            del loss

            self._log_params_update(optimizer)

            ### Get metrics for all progress so far, update tqdm, display description ###
            task_metrics = self._model.get_metrics(task_name=task._name)
            task_metrics["stm_loss"] = float(
                task_info["stm_loss"] /
                (task_info["n_batches_trained_this_epoch"] + 0.000_001))
            task_metrics["s_d_loss"] = float(
                task_info["s_d_loss"] /
                (task_info["n_batches_trained_this_epoch"] + 0.000_001))
            task_metrics["valid_loss"] = float(
                task_info["valid_loss"] /
                (task_info["n_batches_trained_this_epoch"] + 0.000_001))
            description = training_util.description_from_metrics(task_metrics)
            epoch_tqdm.set_description(description)

            self._log_params_values(optimizer, task._name, task_metrics,
                                    histogram_parameters)
            self._global_step += 1

        ### Bookkeeping all the training metrics for all the tasks on the training epoch that just finished ###
        for task in self._task_list:
            task_info = self._task_infos[task._name]

            task_info["total_n_batches_trained"] += task_info[
                "n_batches_trained_this_epoch"]
            task_info["last_log"] = time.time()

            task_metrics = self._model.get_metrics(task_name=task._name,
                                                   reset=True)
            if task._name not in all_tr_metrics:
                all_tr_metrics[task._name] = {}
            for name, value in task_metrics.items():
                all_tr_metrics[task._name][name] = value
            all_tr_metrics[task._name]["stm_loss"] = float(
                task_info["stm_loss"] /
                (task_info["n_batches_trained_this_epoch"] + 0.000_000_01))
            all_tr_metrics[task._name]["s_d_loss"] = float(
                task_info["s_d_loss"] /
                (task_info["n_batches_trained_this_epoch"] + 0.000_000_01))
            all_tr_metrics[task._name]["valid_loss"] = float(
                task_info["valid_loss"] /
                (task_info["n_batches_trained_this_epoch"] + 0.000_000_01))

            # Tensorboard - Training metrics for this epoch
            for metric_name, value in all_tr_metrics[task._name].items():
                self._tensorboard.add_train_scalar(name="task_" + task._name +
                                                   "/" + metric_name,
                                                   value=value)

        logger.info("Train - End")
        return all_tr_metrics
Example #13
    def forward(
            self,  # type: ignore
            task_index: torch.IntTensor,
            reverse: torch.ByteTensor,
            epoch_trained: torch.IntTensor,
            for_training: torch.ByteTensor,
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        embeddeds = self._encoder(task_index, tokens, epoch_trained,
                                  self._valid_discriminator, reverse,
                                  for_training)
        batch_size = get_batch_size(embeddeds["embedded_text"])

        sentiment_logits = self._sentiment_discriminator(
            embeddeds["embedded_text"])

        p_domain_logits = self._p_domain_discriminator(
            embeddeds["private_embedding"])

        # TODO set reverse = true
        s_domain_logits = self._s_domain_discriminator(
            embeddeds["share_embedding"], reverse=reverse)
        # TODO set reverse = true
        # TODO use share_embedding instead of domain_embedding
        valid_logits = self._valid_discriminator(embeddeds["domain_embedding"],
                                                 reverse=reverse)

        valid_label = embeddeds['valid']

        logits = [
            sentiment_logits, p_domain_logits, s_domain_logits, valid_logits
        ]

        # domain_logits = self._domain_discriminator(embedded_text)
        output_dict = {'logits': sentiment_logits}
        if label is not None:
            loss = self._loss(sentiment_logits, label)
            # task_index = task_index.unsqueeze(0)
            task_index = task_index.expand(batch_size)
            targets = [label, task_index, task_index, valid_label]
            # print(p_domain_logits.shape, task_index, task_index.shape)
            p_domain_loss = self._domain_loss(p_domain_logits, task_index)
            s_domain_loss = self._domain_loss(s_domain_logits, task_index)
            logger.info(
                "Share domain logits standard deviation is {}",
                torch.mean(torch.std(F.softmax(s_domain_logits, dim=-1), dim=-1)))
            if self._label_smoothing is not None and self._label_smoothing > 0.0:
                valid_loss = sequence_cross_entropy_with_logits(
                    valid_logits,
                    valid_label.unsqueeze(0).cuda(),
                    torch.tensor(1).unsqueeze(0).cuda(),
                    average="token",
                    label_smoothing=self._label_smoothing)
            else:
                valid_loss = self._valid_loss(
                    valid_logits,
                    torch.zeros(2).scatter_(0, valid_label,
                                            torch.tensor(1.0)).cuda())
            output_dict['stm_loss'] = loss
            output_dict['p_d_loss'] = p_domain_loss
            output_dict['s_d_loss'] = s_domain_loss
            output_dict['valid_loss'] = valid_loss
            # TODO add share domain logits std loss
            output_dict['loss'] = loss + p_domain_loss + 0.005 * s_domain_loss
            # + 0.005 * valid_loss

            # + torch.mean(torch.std(s_domain_logits, dim=1))
            # output_dict['loss'] = loss + p_domain_loss + 0.005 * s_domain_loss

            for (metric, logit, target) in zip(self.metrics.values(), logits,
                                               targets):
                metric(logit, target)

        return output_dict
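The reverse flag passed to the shared-domain and valid discriminators presumably drives a gradient-reversal layer (the training comments speak of reversing the shared classifier's input gradient); below is a standard, self-contained sketch of such a layer, not the project's exact Discriminator implementation:

import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; gradient multiplied by -lambda in the backward pass."""
    @staticmethod
    def forward(ctx, x, lambd=1.0):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # the encoder upstream is pushed to fool the discriminator
        return grad_output.neg() * ctx.lambd, None

def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)

features = torch.randn(3, 8, requires_grad=True)
grad_reverse(features, lambd=0.5).sum().backward()
print(features.grad[0, 0])  # -0.5 instead of +1.0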