def predict(self, loader):
        """
        Runs the trained model over the loader and collects its predictions.

        The model is put in eval mode and every batch is scored under
        torch.no_grad(). Two task types are handled; any other value of
        self.task_type collects nothing.

        * "classification": the score of an instance is column 1 of the
          logits returned by the model.
        * "generation": the score of an instance is the logit of the
          "relevant" token at position 0 of the returned token logits. The
          label is 1 when the first label token is the "relevant" token,
          and 0 otherwise.

        Args:
            loader: the DataLoader containing the set to run the prediction
                and evaluation. Each batch must be a dict whose values
                support .to(device).

        Returns:
            Matrices (logits, labels, softmax_logits), each accumulated per
            query by the utils.acumulate_* helpers.

        Note:
            When self.num_validation_batches is not -1 the loop stops early.
            The check runs after a batch has been processed and uses a
            strict ">", so up to num_validation_batches + 2 batches are
            consumed.
        """

        self.model.eval()

        # The token ids do not change between batches, so look them up once
        # instead of re-encoding the two strings on every iteration.
        if self.task_type == "generation":
            relevant_token_id = self.tokenizer.encode("relevant")[0]
            not_relevant_token_id = self.tokenizer.encode("not_relevant")[0]

        all_logits = []
        all_labels = []
        all_softmax_logits = []
        for idx, batch in tqdm(enumerate(loader),
                               total=len(loader),
                               desc="Predicting"):
            # Only existing keys are re-assigned, so mutating the dict while
            # iterating over it is safe.
            for k, v in batch.items():
                batch[k] = v.to(self.device)

            with torch.no_grad():
                if self.task_type == "classification":
                    outputs = self.model(**batch)
                    _, logits = outputs[:2]
                    # The int cast is required because of the weak
                    # supervision (original author's note).
                    all_labels += batch["labels"].int().tolist()
                    all_logits += logits[:, 1].tolist()
                    all_softmax_logits += torch.softmax(
                        logits, dim=1)[:, 1].tolist()

                elif self.task_type == "generation":
                    outputs = self.model(**batch)
                    _, token_logits = outputs[:2]

                    pred_relevant = token_logits[0:, 0, relevant_token_id]
                    pred_not_relevant = token_logits[0:, 0,
                                                     not_relevant_token_id]
                    # Row 0 holds the "relevant" logits and row 1 the
                    # "not_relevant" ones; a softmax over dim 0 normalises
                    # each instance over the two tokens.
                    both = torch.stack((pred_relevant, pred_not_relevant))

                    all_logits += pred_relevant.tolist()
                    all_labels += [
                        1 if (label_ids[0] == relevant_token_id) else 0
                        for label_ids in batch["labels"]
                    ]
                    all_softmax_logits += torch.softmax(both,
                                                        dim=0)[0].tolist()

            if self.num_validation_batches != -1 and idx > self.num_validation_batches:
                break

        # accumulates per query
        all_labels = utils.acumulate_list_multiple_relevant(all_labels)
        all_logits = utils.acumulate_l1_by_l2(all_logits, all_labels)
        all_softmax_logits = utils.acumulate_l1_by_l2(all_softmax_logits,
                                                      all_labels)
        return all_logits, all_labels, all_softmax_logits
# Example no. 2
# 0
    def predict(self, loader):
        """
        Uses trained model to make predictions on the loader.

        The model is put in eval mode and every batch is scored under
        torch.no_grad(). The score of an instance is column 1 of the logits
        returned by the model.

        Args:
            loader: the DataLoader containing the set to run the prediction
                and evaluation. Each batch must be a dict whose values
                support .to(device) and which contains a "labels" entry.

        Returns:
            A tuple of (logits, labels), each accumulated per query by the
            utils.acumulate_* helpers. For example:
            ([[0.01, 0.12], [1.2, 0.9]])

        Note:
            When self.num_validation_instances is not -1 the loop stops
            early. The check compares the batch index (not an instance
            count), runs after the batch has been processed and uses a
            strict ">".
        """

        self.model.eval()
        all_logits = []
        all_labels = []
        for idx, batch in tqdm(enumerate(loader), total=len(loader)):
            for k, v in batch.items():
                batch[k] = v.to(self.device)

            with torch.no_grad():
                outputs = self.model(**batch)
                _, logits = outputs[:2]
                all_labels += batch["labels"].tolist()
                all_logits += logits[:, 1].tolist()
                # A leftover pdb.set_trace() used to sit here; it stopped
                # every prediction run at the first batch, so it was removed.

            if self.num_validation_instances != -1 and idx > self.num_validation_instances:
                break

        # accumulates per query
        all_labels = utils.acumulate_list_multiple_relevant(all_labels)
        all_logits = utils.acumulate_l1_by_l2(all_logits, all_labels)
        return all_logits, all_labels
    def predict(self, loader):
        """
        Uses trained model to make predictions on the loader.

        The model is put in eval mode and every batch is scored under
        torch.no_grad(). Two task types are handled; any other value of
        self.task_type collects nothing.

        * "classification": the score of an instance is the sigmoid of
          column 1 of the logits returned by the model. The batch's
          "target_doc_id" and "query" entries are collected as well.
        * "generation": the score of an instance is the logit of the
          "relevant" token minus the logit of the "not_relevant" token, both
          taken at position 0 of the returned token logits. The label is 1
          when the first token of batch["lm_labels"] is the "relevant"
          token, and 0 otherwise. Ids and queries are NOT collected in this
          branch, so the corresponding returned lists stay empty.

        Args:
            loader: the DataLoader containing the set to run the prediction
                and evaluation. Each batch must be a dict; every value
                except the one under "query" is moved to self.device.

        Returns:
            A tuple (logits, labels, ids, queries, logits_without_acc):
            logits and labels are accumulated per query by the
            utils.acumulate_* helpers; ids and queries are flat lists with
            one entry per scored instance; logits_without_acc is a flat copy
            of the scores taken before the per-query accumulation.

        Note:
            When self.num_validation_instances is not -1 the loop stops
            early. The check compares the batch index (not an instance
            count), runs after the batch has been processed and uses a
            strict ">".
        """

        self.model.eval()

        # The token ids do not change between batches, so look them up once
        # instead of re-encoding the two strings on every iteration.
        if self.task_type == "generation":
            relevant_token_id = self.tokenizer.encode("relevant")[0]
            not_relevant_token_id = self.tokenizer.encode("not_relevant")[0]

        all_logits = []
        all_labels = []
        all_ids = []
        all_queries = []
        for idx, batch in tqdm(enumerate(loader), total=len(loader)):
            for k, v in batch.items():
                # batch["query"] is left as is; it is not moved to the device.
                if k != 'query':
                    batch[k] = v.to(self.device)
            with torch.no_grad():
                if self.task_type == "classification":
                    outputs = self.model(attention_mask=batch['attention_mask'],
                                         input_ids=batch['input_ids'],
                                         token_type_ids=batch['token_type_ids'],
                                         labels=batch['labels'])

                    _, logits = outputs[:2]
                    all_labels += batch["labels"].tolist()
                    # This is a sigmoid over the positive-class logit (the
                    # original code applied nn.Sigmoid under the misleading
                    # local name "softmax").
                    all_logits += torch.sigmoid(logits[:, 1]).tolist()
                    all_ids += batch["target_doc_id"].tolist()
                    all_queries += batch["query"]

                elif self.task_type == "generation":
                    outputs = self.model(attention_mask=batch['attention_mask'],
                                         input_ids=batch['input_ids'],
                                         token_type_ids=batch['token_type_ids'],
                                         labels=batch['labels'])
                    _, token_logits = outputs[:2]

                    pred_relevant = token_logits[0:, 0, relevant_token_id]
                    pred_not_relevant = token_logits[0:, 0, not_relevant_token_id]
                    pred = pred_relevant - pred_not_relevant

                    all_logits += pred.tolist()
                    # NOTE(review): labels are read from "lm_labels" here
                    # while the model is fed batch['labels'] - confirm both
                    # keys exist in generation batches.
                    all_labels += [
                        1 if (label_ids[0] == relevant_token_id) else 0
                        for label_ids in batch["lm_labels"]
                    ]

            if self.num_validation_instances != -1 and idx > self.num_validation_instances:
                break

        # accumulates per query
        all_labels = utils.acumulate_list_multiple_relevant(all_labels)
        # Keep the flat scores before they are regrouped per query.
        all_logits_without_acc = all_logits.copy()
        all_logits = utils.acumulate_l1_by_l2(all_logits, all_labels)
        return all_logits, all_labels, all_ids, all_queries, all_logits_without_acc
    def predict_with_uncertainty(self, loader, foward_passes):
        """
        Uses trained model to make predictions on the loader with uncertainty estimations.

        This methods uses MC dropout to get the predicted relevance (mean) and uncertainty (variance)
        by enabling dropout at test time and making K foward passes.

        See "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning"
        https://arxiv.org/abs/1506.02142.

        The model is first put in eval mode; then every sub-module whose
        class name starts with "Dropout" is switched back to train mode, so
        that each forward pass over the same batch uses a different dropout
        mask. Two task types are handled:

        * "classification": the score of an instance is column 1 of the
          logits returned by the model.
        * "generation": the score of an instance is the logit of the
          "relevant" token at position 0 of the returned token logits. The
          label is 1 when the first label token is the "relevant" token,
          and 0 otherwise.

        Args:
            loader: DataLoader containing the set to run the prediction and evaluation.
                Each batch must be a dict whose values support .to(device).
            foward_passes: int indicating the number of foward prediction passes for each instance.

        Returns:
            Matrices (logits, labels, softmax_logits, foward_passes_logits, uncertainties):
            The logits (mean over the passes) for every instance, the labels,
            the softmax_logits (mean over the passes), all predictions
            obtained during the passes (foward_passes_logits, one entry per
            pass) and the uncertainties (variance over the passes, as
            computed by numpy's var). Everything is accumulated per query by
            the utils.acumulate_* helpers.

        Note:
            When self.num_validation_batches is not -1 the loop stops early.
            The check runs after a batch has been processed and uses a
            strict ">".
        """
        def enable_dropout(model):
            # Re-activate only the dropout layers; everything else stays in
            # eval mode.
            for module in model.modules():
                if module.__class__.__name__.startswith('Dropout'):
                    module.train()

        self.model.eval()
        enable_dropout(self.model)

        if self.task_type == "generation":
            relevant_token_id = self.tokenizer.encode("relevant")[0]
            not_relevant_token_id = self.tokenizer.encode("not_relevant")[0]

        logits = []
        labels = []
        uncertainties = []
        softmax_logits = []
        # foward_passes X queries
        foward_passes_logits = [[] for _ in range(foward_passes)]
        for idx, batch in tqdm(enumerate(loader), total=len(loader)):
            for k, v in batch.items():
                batch[k] = v.to(self.device)

            with torch.no_grad():
                # One row per forward pass, one column per instance of the
                # batch.
                fwrd_predictions = []
                fwrd_softmax_predictions = []
                if self.task_type == "classification":
                    labels += batch["labels"].tolist()
                    for i in range(foward_passes):
                        outputs = self.model(**batch)
                        _, batch_logits = outputs[:2]

                        pass_scores = batch_logits[:, 1].tolist()
                        fwrd_predictions.append(pass_scores)
                        fwrd_softmax_predictions.append(
                            torch.softmax(batch_logits, dim=1)[:, 1].tolist())
                        foward_passes_logits[i] += pass_scores
                elif self.task_type == "generation":
                    labels += [
                        1 if (label_ids[0] == relevant_token_id) else 0
                        for label_ids in batch["labels"]
                    ]
                    for i in range(foward_passes):
                        outputs = self.model(**batch)
                        _, token_logits = outputs[:2]
                        pred_relevant = token_logits[0:, 0, relevant_token_id]
                        pred_not_relevant = token_logits[0:, 0,
                                                         not_relevant_token_id]
                        # Row 0 holds the "relevant" logits and row 1 the
                        # "not_relevant" ones; a softmax over dim 0
                        # normalises each instance over the two tokens.
                        both = torch.stack((pred_relevant, pred_not_relevant))

                        pass_scores = pred_relevant.tolist()
                        fwrd_predictions.append(pass_scores)
                        fwrd_softmax_predictions.append(
                            torch.softmax(both, dim=0)[0].tolist())
                        foward_passes_logits[i] += pass_scores

                # Mean and variance across the passes (axis 0) for each
                # instance of the batch.
                pass_matrix = np.array(fwrd_predictions)
                logits += pass_matrix.mean(axis=0).tolist()
                uncertainties += pass_matrix.var(axis=0).tolist()
                softmax_logits += np.array(fwrd_softmax_predictions).mean(
                    axis=0).tolist()
            if self.num_validation_batches != -1 and idx > self.num_validation_batches:
                break

        # accumulates per query
        labels = utils.acumulate_list_multiple_relevant(labels)
        logits = utils.acumulate_l1_by_l2(logits, labels)
        uncertainties = utils.acumulate_l1_by_l2(uncertainties, labels)
        softmax_logits = utils.acumulate_l1_by_l2(softmax_logits, labels)
        for i, foward_logits in enumerate(foward_passes_logits):
            foward_passes_logits[i] = utils.acumulate_l1_by_l2(
                foward_logits, labels)

        return logits, labels, softmax_logits, foward_passes_logits, uncertainties