示例#1
0
    def compute_correct(self,
                        model: NeuralModelBase,
                        batch: MiniBatch,
                        verbose=False,
                        threshold=0.5):
        """Compare the model's reject-aware predictions with the batch targets.

        Args:
            model: supplies ``padding_mask(batch)`` and ``predict_with_reject``.
            batch: supplies the targets ``batch.Y`` and ``batch.masks["mask_valid"]``.
            verbose: accepted for interface compatibility; not used here.
            threshold: forwarded unchanged to ``model.predict_with_reject``.

        Returns:
            A ``(correct, predictions, Y)`` tuple. Each element is the result
            of ``masked_select`` with the ``mask_valid`` mask, so all three are
            aligned position by position. ``correct`` is ``predictions == Y``
            multiplied by the model's padding mask.
        """
        pad_mask = model.padding_mask(batch)
        valid = batch.masks["mask_valid"]
        targets = batch.Y
        predicted = model.predict_with_reject(batch,
                                              self.dataset.reject_token_id,
                                              threshold=threshold)

        hits = (predicted == targets) * pad_mask
        if targets.dim() == 2:
            # The adversary assumes that the order of predicted labels in the
            # adversarial and the original batch is the same. All modifications
            # respect this: while code is added or removed, the order and
            # number of predicted labels (via 'mask_valid') stay the same.
            # However, since the sample length might change, masked_select has
            # to be performed batch-first (B x N rather than N x B), otherwise
            # the results are not guaranteed to preserve the order.
            # TODO: check that the batch is created with batch_first=False
            # (this is the default with torchtext)
            valid = valid.t()
            hits, predicted, targets = hits.t(), predicted.t(), targets.t()

        return (
            hits.masked_select(valid),
            predicted.masked_select(valid),
            targets.masked_select(valid),
        )
示例#2
0
def get_rejection_thresholds(
    it, model: NeuralModelBase, dataset: Dataset, precision_thresholds: Iterable[float]
):
    """Choose a reject-probability cut-off for each requested precision.

    Every prediction selected by the ``mask_valid`` padding mask is placed in
    one of ``num_bins`` equal-width histogram bins over [0, 1] according to
    its reject probability. The bins are then scanned in increasing order
    while accumulating correct/total counts; for bin ``i`` the cut-off
    ``i / num_bins`` is recorded for a precision target when the predictions
    accumulated in the bins *before* ``i`` reach that precision.

    Args:
        it: iterable of batches.
        model: supplies ``predict_probs_with_reject`` and ``padding_mask``.
        dataset: supplies ``reject_token_id``.
        precision_thresholds: target precisions, compared against
            ``correct / total`` (a fraction). Any iterable is accepted,
            including one-shot generators; it is materialised once.

    Returns:
        A list with one ``SimpleNamespace(h, size)`` per precision target, in
        the same order. ``h`` is the chosen cut-off (``None`` if the target
        was never reached) and ``size`` is the number of accumulated
        predictions at the moment the cut-off was recorded.
    """
    num_bins = 1000
    # The targets are walked once per histogram bin below, so a one-shot
    # iterable (e.g. a generator) must be materialised up front; otherwise it
    # would be exhausted after the first pass and no cut-off would be found.
    precision_targets = list(precision_thresholds)

    num_correct = torch.zeros(num_bins)
    num_total = torch.zeros(num_bins)
    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        mask = model.padding_mask(batch, mask_field="mask_valid")
        targets = batch.Y

        best_predictions = best_predictions.masked_select(mask)
        # The histogram accumulators live on the CPU, so move the data there.
        reject_probs = reject_probs.masked_select(mask).cpu()
        targets = targets.masked_select(mask)

        is_corrects = (targets == best_predictions).cpu()

        num_total.add_(torch.histc(reject_probs, bins=num_bins, min=0, max=1))
        num_correct.add_(
            torch.histc(
                reject_probs.masked_select(is_corrects), bins=num_bins, min=0, max=1
            )
        )

    def precision(stat):
        # An empty prefix has no precision; treat it as 0.
        if stat.total == 0:
            return 0
        return stat.correct * 1.0 / stat.total

    thresholds = [SimpleNamespace(h=None, size=0) for _ in precision_targets]
    rolling_stat = SimpleNamespace(correct=0, total=0)
    # NOTE(review): rolling_stat is tested *before* bin i is added and the
    # totals after the last bin are never tested, so the largest cut-off that
    # can be returned is (num_bins - 1) / num_bins. Confirm this is intended.
    for i, (correct, total) in enumerate(
        zip(num_correct.numpy(), num_total.numpy())
    ):
        current_precision = precision(rolling_stat)
        for t, precision_target in zip(thresholds, precision_targets):
            if precision_target <= current_precision:
                # update threshold if it's not set or the number of samples
                # grew by more than 1% since it was last recorded
                if t.h is None or t.size * 1.01 < rolling_stat.total:
                    t.h = i / float(num_bins)
                    t.size = int(rolling_stat.total)

        rolling_stat.correct += correct
        rolling_stat.total += total

    Logger.debug(
        "Thresholds: {}, sizes: {}".format(
            [t.h for t in thresholds], [t.size for t in thresholds]
        )
    )
    return thresholds
示例#3
0
    def visualize_adversarial(self, batch: MiniBatch, model: NeuralModelBase,
                              masks, colors):
        """Show the first sample of ``batch`` with the given masks overlaid.

        Each entry of ``masks`` is scattered into a tensor shaped like the
        model's ``mask_valid`` padding mask (zero elsewhere), attached to the
        batched graph as a temporary node feature, and split per sample with
        ``dgl.unbatch``. Only the first (tree, sample) pair is rendered; the
        method then blocks on ``input()`` before returning.

        Args:
            batch: batch whose ``X`` is the batched graph to visualise.
            model: must be a ``GraphModel``; supplies ``padding_mask``.
            masks: iterable of values assignable to the positions selected by
                the padding mask; the i-th mask is stored under key ``str(i)``.
            colors: colours paired with ``masks`` by position.

        Raises:
            AssertionError: if ``model`` is not a ``GraphModel`` or a
                temporary feature name already exists on the graph.
        """
        assert isinstance(model, GraphModel)
        samples = self.dataset.samples_for_batch(batch)

        training_mask = model.padding_mask(batch, mask_field="mask_valid")

        g = batch.X
        # Insertion order of this dict is the order of `masks`, which is what
        # pairs each feature name with its colour further down.
        mask_data = {}
        for idx, mask in enumerate(masks):
            mask_all = torch.zeros_like(training_mask)
            mask_all[training_mask] = mask
            mask_data[str(idx)] = mask_all

        # The features are only needed so that dgl.unbatch carries them over
        # to the per-sample graphs. Remove whatever was attached even when
        # something fails, so the caller's graph is never left modified.
        attached = []
        try:
            for name, mask in mask_data.items():
                assert name not in g.ndata
                g.ndata[name] = mask
                attached.append(name)
            trees = dgl.unbatch(g)
        finally:
            for name in attached:
                del g.ndata[name]

        for tree, sample in zip(trees, samples):
            mask_labels = [
                MaskLabel.from_values(tree.ndata[name], color)
                for name, color in zip(mask_data, colors)
            ]
            TreeVisualization.visualize(
                sample,
                ["types", "values"],
                self.dataset.dtrain.fields,
                labels=mask_labels + [
                    FieldLabel(self.dataset.dtrain.fields["types"],
                               tree.ndata["types"]),
                    FieldLabel(self.dataset.dtrain.fields["values"],
                               tree.ndata["values"]),
                    # TODO: add a top-k prediction label
                ],
            )
            # Wait for the user to press Enter, then stop: only the first
            # sample of the batch is shown.
            input()
            break
示例#4
0
def print_rejection_thresholds(it, model: NeuralModelBase, dataset: Dataset):
    """Log prediction accuracy at a range of reject-probability cut-offs.

    For every cut-off ``h`` in ``np.arange(0.1, 1.1, 0.1)`` the predictions
    whose reject probability is ``<= h`` are counted, together with how many
    of them match the target. One debug line is logged per cut-off, followed
    by one line with the overall counts across all predictions selected by
    the ``mask_valid`` padding mask.

    Args:
        it: iterable of batches.
        model: supplies ``predict_probs_with_reject`` and ``padding_mask``.
        dataset: supplies ``reject_token_id``.
    """
    cutoffs = np.arange(0.1, 1.1, 0.1)
    per_cutoff = {h: SimpleNamespace(correct=0, total=0) for h in cutoffs}
    overall_correct = 0
    overall_total = 0

    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best, reject = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        valid = model.padding_mask(batch, mask_field="mask_valid")
        gold = batch.Y

        # Keep only the positions selected by the mask; all three tensors
        # stay aligned because they are filtered with the same mask.
        best = best.masked_select(valid)
        reject = reject.masked_select(valid)
        gold = gold.masked_select(valid)

        hits = gold == best
        overall_correct += hits.sum().item()
        overall_total += gold.numel()

        for h, stat in per_cutoff.items():
            accepted = reject <= h
            stat.total += accepted.sum().item()
            stat.correct += hits.masked_select(accepted).sum().item()

    for h, stat in per_cutoff.items():
        Logger.debug(
            "Threshold {:5.2f}: {:6d}/{:6d} ({:.2f}%)".format(
                h, stat.correct, stat.total, acc(stat.correct, stat.total)
            )
        )

    Logger.debug(
        "{:6d}/{:6d} ({:.2f}%)".format(
            overall_correct, overall_total, acc(overall_correct, overall_total)
        )
    )