Example #1
    def test_model_on_dataset(self, args):
        model = parse_model_from_args(args)
        dataset = parse_dataset_from_args(args)

        preds = []
        ground_truth_outputs = []
        i = 0

        # Step through the dataset in chunks of `model_batch_size`.
        while i < args.num_examples:
            dataset_batch = dataset[i:min(args.num_examples, i +
                                          args.model_batch_size)]
            batch_inputs = []
            for (text_input, ground_truth_output) in dataset_batch:
                attacked_text = textattack.shared.AttackedText(text_input)
                batch_inputs.append(attacked_text.tokenizer_input)
                ground_truth_outputs.append(ground_truth_output)
            batch_preds = self.get_preds(model, batch_inputs)

            # Some models return lists or numpy arrays; normalize to a tensor.
            if not isinstance(batch_preds, torch.Tensor):
                batch_preds = torch.Tensor(batch_preds)

            preds.extend(batch_preds)
            i += args.model_batch_size

        preds = torch.stack(preds).squeeze().cpu()
        ground_truth_outputs = torch.tensor(ground_truth_outputs).cpu()

        logger.info(f"Got {len(preds)} predictions.")

        if preds.ndim == 1:
            # if preds is just a list of numbers, assume regression for now
            # TODO integrate with `textattack.metrics` package
            pearson_correlation, _ = scipy.stats.pearsonr(
                ground_truth_outputs, preds)
            spearman_correlation, _ = scipy.stats.spearmanr(
                ground_truth_outputs, preds)

            logger.info(f"Pearson correlation = {_cb(pearson_correlation)}")
            logger.info(f"Spearman correlation = {_cb(spearman_correlation)}")
        else:
            guess_labels = preds.argmax(dim=1)
            successes = (guess_labels == ground_truth_outputs).sum().item()
            perc_accuracy = successes / len(preds) * 100.0
            perc_accuracy = "{:.2f}%".format(perc_accuracy)
            logger.info(
                f"Correct {successes}/{len(preds)} ({_cb(perc_accuracy)})")
Example #2
    def run(self, args):
        UPPERCASE_LETTERS_REGEX = re.compile("[A-Z]")

        # `parse_dataset_from_args` expects these attributes to exist,
        # so clear them explicitly before parsing the dataset.
        args.model, args.model_from_file = None, None
        dataset = parse_dataset_from_args(args)

        num_words = []
        attacked_texts = []
        data_all_lowercased = True
        outputs = []
        for inputs, output in dataset:
            at = textattack.shared.AttackedText(inputs)
            if data_all_lowercased:
                # Check whether the text contains any uppercase letters.
                if re.search(UPPERCASE_LETTERS_REGEX, at.text):
                    data_all_lowercased = False
            attacked_texts.append(at)
            num_words.append(len(at.words))
            outputs.append(output)

        logger.info(f"Number of samples: {_cb(len(attacked_texts))}")
        logger.info("Number of words per input:")
        num_words = np.array(num_words)
        logger.info(f'\t{("total:").ljust(8)} {_cb(num_words.sum())}')
        mean_words = f"{num_words.mean():.2f}"
        logger.info(f'\t{("mean:").ljust(8)} {_cb(mean_words)}')
        std_words = f"{num_words.std():.2f}"
        logger.info(f'\t{("std:").ljust(8)} {_cb(std_words)}')
        logger.info(f'\t{("min:").ljust(8)} {_cb(num_words.min())}')
        logger.info(f'\t{("max:").ljust(8)} {_cb(num_words.max())}')
        logger.info(f"Dataset lowercased: {_cb(data_all_lowercased)}")

        logger.info("First sample:")
        print(attacked_texts[0].printable_text(), "\n")
        logger.info("Last sample:")
        print(attacked_texts[-1].printable_text(), "\n")

        logger.info(f"Found {len(set(outputs))} distinct outputs.")
        if len(set(outputs)) < 20:
            print(sorted(set(outputs)))

        logger.info("Most common outputs:")
        # Truncate each label to five characters so the counts line up.
        for key, value in collections.Counter(outputs).most_common(20):
            print("\t", str(key)[:5].ljust(5), f" ({value})")
Example #3
    def test_model_on_dataset(self, args):
        model = parse_model_from_args(args)
        dataset = parse_dataset_from_args(args)

        # Evaluate only the first `num_examples` samples.
        dataset = dataset[:args.num_examples]
        attacked_texts = [
            textattack.shared.AttackedText(text) for text, _ in dataset
        ]
        ground_truth_outputs = [label for _, label in dataset]

        inputs = textattack.shared.utils.batch_tokenize(
            model.tokenizer, attacked_texts, batch_size=args.batch_size)
        preds = self.get_preds(model, inputs, batch_size=args.batch_size)

        preds = preds.cpu()
        ground_truth_outputs = torch.tensor(ground_truth_outputs).cpu()

        logger.info(f"Got {len(preds)} predictions.")

        if preds.ndim == 1:
            # if preds is just a list of numbers, assume regression for now
            # TODO integrate with `textattack.metrics` package
            pearson_correlation, _ = scipy.stats.pearsonr(
                ground_truth_outputs, preds)
            spearman_correlation, _ = scipy.stats.spearmanr(
                ground_truth_outputs, preds)

            logger.info(f"Pearson correlation = {_cb(pearson_correlation)}")
            logger.info(f"Spearman correlation = {_cb(spearman_correlation)}")
        else:
            guess_labels = preds.argmax(dim=1)
            successes = (guess_labels == ground_truth_outputs).sum().item()
            perc_accuracy = successes / len(preds) * 100.0
            perc_accuracy = "{:.2f}%".format(perc_accuracy)
            logger.info(
                f"Successes {successes}/{len(preds)} ({_cb(perc_accuracy)})")