Exemplo n.º 1
0
    def compute_correct(self,
                        model: NeuralModelBase,
                        batch: MiniBatch,
                        verbose=False,
                        threshold=0.5):
        """Return correctness flags, predictions and targets at the valid positions.

        Predictions come from ``model.predict_with_reject`` using the dataset's
        reject token and ``threshold``. A position counts as correct when the
        prediction equals the target and the position is kept by
        ``model.padding_mask``. All three results are then restricted to the
        positions selected by ``batch.masks["mask_valid"]``.

        ``verbose`` is accepted but not used by this method.

        Returns:
            Tuple ``(correct, predictions, Y)`` of flattened tensors.
        """
        padding_mask = model.padding_mask(batch)
        valid_mask = batch.masks["mask_valid"]
        targets = batch.Y
        predictions = model.predict_with_reject(
            batch, self.dataset.reject_token_id, threshold=threshold)

        correct = (predictions == targets) * padding_mask

        # The adversary assumes that the predicted labels of the adversarial
        # and the original batch appear in the same order. All modifications
        # respect this: code may be added or removed, but the order and number
        # of predicted labels (via 'mask_valid') stay the same. Because the
        # sample length might change, masked_select has to run batch-first
        # (B x N rather than N x B); otherwise the order of the selected
        # elements is not guaranteed to be preserved.
        # TODO: check that the batch is created with batch_first=False (this is the default with torchtext)
        if targets.dim() == 2:
            valid_mask = valid_mask.t()
            correct = correct.t()
            predictions = predictions.t()
            targets = targets.t()

        return (
            correct.masked_select(valid_mask),
            predictions.masked_select(valid_mask),
            targets.masked_select(valid_mask),
        )
Exemplo n.º 2
0
def get_rejection_thresholds(
    it, model: NeuralModelBase, dataset: Dataset, precision_thresholds: Iterable[float]
):
    """Pick, for each requested precision, a cut-off on the reject probability.

    The reject probabilities of all positions selected by ``mask_valid`` are
    histogrammed into ``num_bins`` equal-width bins over [0, 1], once for all
    positions and once for the correctly predicted ones. The bins are then
    scanned in increasing order while both counts are accumulated: at bin
    ``i`` the accumulated counts cover bins ``0 .. i - 1`` and the candidate
    cut-off is ``i / num_bins``.

    Args:
        it: Iterable of batches to evaluate.
        model: Model providing ``predict_probs_with_reject`` and
            ``padding_mask``.
        dataset: Dataset providing ``reject_token_id``.
        precision_thresholds: Required precisions. Any iterable is accepted,
            including one-shot iterators.

    Returns:
        One ``SimpleNamespace(h, size)`` per requested precision, in input
        order. ``h`` is the chosen cut-off, or ``None`` if the accumulated
        precision never reached the requested value. ``size`` is the number of
        accumulated samples at that cut-off.
    """
    # The thresholds are traversed once to build the result list and again for
    # every histogram bin. A one-shot iterator would be exhausted after the
    # first pass and every cut-off would silently stay None.
    precision_thresholds = list(precision_thresholds)
    num_bins = 1000

    num_correct = torch.zeros(num_bins)
    num_total = torch.zeros(num_bins)
    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        mask = model.padding_mask(batch, mask_field="mask_valid")
        targets = batch.Y

        best_predictions = best_predictions.masked_select(mask)
        reject_probs = reject_probs.masked_select(mask).cpu()
        targets = targets.masked_select(mask)

        is_corrects = (targets == best_predictions).cpu()

        num_total.add_(torch.histc(reject_probs, bins=num_bins, min=0, max=1))
        num_correct.add_(
            torch.histc(
                reject_probs.masked_select(is_corrects), bins=num_bins, min=0, max=1
            )
        )

    def precision(stat):
        # An empty prefix has no precision; report 0 so it never qualifies.
        if stat.total == 0:
            return 0
        return stat.correct * 1.0 / stat.total

    thresholds = [SimpleNamespace(h=None, size=0) for _ in precision_thresholds]
    rolling_stat = SimpleNamespace(correct=0, total=0)
    # NOTE(review): the accumulated counts are checked before the current bin
    # is added, so the state that includes the last bin is never evaluated and
    # h cannot become 1.0 — confirm this is intended.
    for i, (correct, total) in enumerate(
        zip(num_correct.numpy(), num_total.numpy())
    ):
        for t, precision_threshold in zip(thresholds, precision_thresholds):
            if precision_threshold <= precision(rolling_stat):
                # update threshold if it's not set or the number of samples
                # increased by more than 1%
                if t.h is None or t.size * 1.01 < rolling_stat.total:
                    t.h = i / float(num_bins)
                    t.size = int(rolling_stat.total)

        rolling_stat.correct += correct
        rolling_stat.total += total

    Logger.debug(
        "Thresholds: {}, sizes: {}".format(
            [t.h for t in thresholds], [t.size for t in thresholds]
        )
    )
    return thresholds
Exemplo n.º 3
0
def load_model(model: NeuralModelBase, args, model_id):
    """Load the checkpoint for ``model_id`` into ``model`` if it exists.

    The checkpoint path is built from ``checkpoint_dir(args)`` and
    ``checkpoint_name(args, model_id)``.

    Returns:
        ``True`` if the checkpoint file existed and its state dict was loaded
        into ``model``; ``False`` if the file does not exist.
    """
    import torch

    checkpoint_file = os.path.join(checkpoint_dir(args),
                                   checkpoint_name(args, model_id))
    if not os.path.exists(checkpoint_file):
        Logger.debug("Checkpoint {} not found".format(checkpoint_file))
        return False

    Logger.debug("Loading model from {}".format(checkpoint_file))
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # loaded on a host without that device; load_state_dict then copies the
    # values into the model's existing parameters.
    # NOTE: torch.load unpickles the file — only load checkpoints from a
    # trusted source.
    data = torch.load(checkpoint_file, map_location="cpu")
    model.load_state_dict(data)
    return True
Exemplo n.º 4
0
def save_model(model: NeuralModelBase, args, model_id):
    """Write the model's state dict to the checkpoint file for ``model_id``.

    The target path is ``checkpoint_dir(args)`` joined with
    ``checkpoint_name(args, model_id)``.
    """
    import torch

    directory = checkpoint_dir(args)
    filename = checkpoint_name(args, model_id)
    target_path = os.path.join(directory, filename)

    Logger.debug("Saving model to {}".format(target_path))
    state = model.state_dict()
    torch.save(state, target_path)
Exemplo n.º 5
0
def train_base_model(
    model: NeuralModelBase,
    dataset: Dataset,
    num_epochs,
    train_iter,
    valid_iter,
    lr=0.001,
    verbose=True,
):
    """Train ``model`` with Adam and an unreduced cross-entropy loss.

    The optimizer and the loss are also stored on the model as ``model.opt``
    and ``model.loss_function``. After every epoch the model is evaluated on
    each validation iterator.

    Args:
        valid_iter: A single validation iterator or a list of them.

    Returns:
        ``(train_prec, valid_prec)``, both read from the
        ``"mask_valid_noreject_acc"`` entry of ``model.accuracy``.
        ``valid_prec`` is the value of the last validation iterator in the
        last epoch, or ``None`` if no epoch was run.
    """
    if isinstance(valid_iter, list):
        validation_sets = valid_iter
    else:
        validation_sets = [valid_iter]

    Logger.start_scope("Training Model")
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.opt = optimizer
    criterion = nn.CrossEntropyLoss(reduction="none")
    model.loss_function = criterion

    valid_prec = None
    for epoch in range(num_epochs):
        Logger.start_scope("Epoch {}".format(epoch))
        model.fit(train_iter, optimizer, criterion, mask_field="mask_valid")

        for validation_set in validation_sets:
            stats = model.accuracy(validation_set,
                                   dataset.TARGET,
                                   verbose=verbose)
            valid_prec = stats["mask_valid_noreject_acc"]
            Logger.debug(f"valid_prec: {valid_prec}")
        Logger.end_scope()

    # Final accuracy on the training data, measured once after all epochs.
    train_prec = model.accuracy(
        train_iter, dataset.TARGET, verbose=False)["mask_valid_noreject_acc"]
    Logger.debug(f"train_prec: {train_prec}, valid_prec: {valid_prec}")
    Logger.end_scope()
    return train_prec, valid_prec
Exemplo n.º 6
0
def print_rejection_thresholds(it, model: NeuralModelBase, dataset: Dataset):
    """Log prediction accuracy for several reject-probability cut-offs.

    For each cut-off ``h`` in 0.1, 0.2, ..., 1.0, only the valid positions
    whose reject probability is ``<= h`` are counted. One line per cut-off is
    logged, followed by the overall accuracy across all valid positions.
    Percentages are produced by the external ``acc`` helper.
    """
    thresholds = np.arange(0.1, 1.1, 0.1)
    # Parallel per-threshold counters, indexed like ``thresholds``.
    kept_correct = [0] * len(thresholds)
    kept_total = [0] * len(thresholds)
    overall_correct = 0
    overall_total = 0

    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        valid = model.padding_mask(batch, mask_field="mask_valid")

        targets = batch.Y.masked_select(valid)
        best_predictions = best_predictions.masked_select(valid)
        reject_probs = reject_probs.masked_select(valid)

        hits = targets == best_predictions
        overall_correct += torch.sum(hits).item()
        overall_total += targets.numel()

        for slot, h in enumerate(thresholds):
            accepted = reject_probs <= h
            kept_total[slot] += torch.sum(accepted).item()
            kept_correct[slot] += torch.sum(hits.masked_select(accepted)).item()

    for h, n_correct, n_total in zip(thresholds, kept_correct, kept_total):
        Logger.debug(
            "Threshold {:5.2f}: {:6d}/{:6d} ({:.2f}%)".format(
                h, n_correct, n_total, acc(n_correct, n_total)
            )
        )

    Logger.debug(
        "{:6d}/{:6d} ({:.2f}%)".format(
            overall_correct, overall_total, acc(overall_correct, overall_total)
        )
    )
Exemplo n.º 7
0
    def visualize_adversarial(self, batch: MiniBatch, model: NeuralModelBase,
                              masks, colors):
        """Visualize the first tree of ``batch`` with the given masks overlaid.

        Each entry of ``masks`` is scattered onto the positions selected by the
        model's ``mask_valid`` padding mask and temporarily stored as node data
        on the batched graph, so that ``dgl.unbatch`` splits it per tree. The
        temporary node data is removed again right after unbatching. Only the
        first (tree, sample) pair is shown, then the method waits for
        ``input()``.

        Only supported for ``GraphModel`` instances (asserted).
        """
        assert isinstance(model, GraphModel)
        samples = self.dataset.samples_for_batch(batch)
        valid_positions = model.padding_mask(batch, mask_field="mask_valid")
        graph = batch.X

        # Expand every mask to one entry per position; the dict key is the
        # index of the mask rendered as a string.
        node_masks = {}
        for position, mask in enumerate(masks):
            expanded = torch.zeros_like(valid_positions)
            expanded[valid_positions] = mask
            node_masks[str(position)] = expanded
        mask_names = list(node_masks)

        for name, expanded in node_masks.items():
            assert name not in graph.ndata
            graph.ndata[name] = expanded
        trees = dgl.unbatch(graph)
        for name in node_masks:
            del graph.ndata[name]

        first_pair = next(iter(zip(trees, samples)), None)
        if first_pair is None:
            return
        tree, sample = first_pair

        mask_labels = []
        for name, color in zip(mask_names, colors):
            mask_labels.append(MaskLabel.from_values(tree.ndata[name], color))

        fields = self.dataset.dtrain.fields
        field_labels = [
            FieldLabel(fields["types"], tree.ndata["types"]),
            FieldLabel(fields["values"], tree.ndata["values"]),
            # TODO: topk
            # TopkLabel.from_iter([batch], sample, model, self.dataset,
            #                     mask=MaskLabel.from_sample(sample, 'mask_valid')),
        ]
        TreeVisualization.visualize(
            sample,
            ["types", "values"],
            fields,
            labels=mask_labels + field_labels,
        )
        input()
Exemplo n.º 8
0
def train_model(
    model: NeuralModelBase,
    dataset: Dataset,
    num_epochs,
    train_iter,
    valid_iter,
    lr=0.001,
    weight=None,
    target_o=1.0,
):
    """Train ``model`` with a rejection loss whose ``o`` is annealed per epoch.

    With ``half = num_epochs // 2`` the schedule has three phases of ``half``
    epochs each, so ``3 * half`` epochs are run in total:

    1. ``o`` is interpolated linearly from ``o_base`` towards 1.0;
    2. ``o`` is interpolated linearly from the midpoint between 1.0 and the
       last value of phase 1 towards ``target_o``;
    3. ``o`` is held at ``target_o``.

    If ``num_epochs < 2`` there is no room for annealing, and ``num_epochs``
    epochs are run directly at ``target_o``. This previously raised
    ``ZeroDivisionError``.

    The optimizer and the loss are also stored on the model as ``model.opt``
    and ``model.loss_function``.

    Returns:
        ``(train_prec, valid_prec)``, both read from the
        ``"mask_valid_noreject_acc"`` entry of ``model.accuracy``.
        ``valid_prec`` is ``None`` if no epoch was run.
    """
    opt = optim.Adam(model.parameters(), lr=lr)
    Logger.start_scope("Training Model")

    # NOTE(review): 4 entries are subtracted, although the original comment
    # named only three special tokens ('reject', '<unk>', '<pad>') — confirm
    # what the fourth one is.
    o_base = len(dataset.TARGET.vocab) - 4
    loss_function = RejectionCrossEntropyLoss(
        o_base,
        len(dataset.TARGET.vocab),
        dataset.reject_token_id,
        reduction="none",
        weight=weight,
    )
    model.loss_function = loss_function
    model.opt = opt

    half = num_epochs // 2
    if half <= 0:
        # Too few epochs to anneal; a non-positive count yields an empty list.
        schedule = [target_o] * num_epochs
    else:
        # Interpolation weights 1, 1 - 1/half, ..., 1/half. linspace fixes the
        # number of elements, whereas arange with a float step may produce one
        # element too many because of rounding.
        fractions = np.linspace(1.0, 0.0, num=half, endpoint=False)
        schedule = [f * o_base + (1 - f) * 1.0 for f in fractions]
        phase_two_start = (1.0 + schedule[-1]) / 2
        schedule += [
            f * phase_two_start + (1 - f) * target_o for f in fractions
        ]
        schedule += [target_o] * half

    train_prec, valid_prec = None, None
    for epoch, o_upper in enumerate(schedule):
        Logger.start_scope("Epoch {}, o_upper={:.3f}".format(epoch, o_upper))
        loss_function.o = o_upper
        model.fit(train_iter, opt, loss_function, mask_field="mask_valid")

        valid_stats = model.accuracy(valid_iter, dataset.TARGET)
        valid_prec = valid_stats["mask_valid_noreject_acc"]
        Logger.debug(f"valid_prec: {valid_prec}")
        Logger.end_scope()

    train_stats = model.accuracy(train_iter, dataset.TARGET, verbose=False)
    train_prec = train_stats["mask_valid_noreject_acc"]
    Logger.debug(f"train_prec: {train_prec}, valid_prec: {valid_prec}")
    Logger.end_scope()
    return train_prec, valid_prec
Exemplo n.º 9
0
    def __attack_graph(
        self,
        model: NeuralModelBase,
        batch: MiniBatch,
        mask: torch.Tensor,
        num_samples,
        shuffle: ShuffleStrategy,
        mode: AdversarialMode,
    ):
        """Yield ``batch`` repeatedly with adversarially modified node values.

        This is a generator. On each iteration the ``"values"`` node data of
        the batched graph ``batch.X`` is replaced by a modified copy and the
        same ``batch`` object is yielded. When the generator finishes or is
        closed early, the original ``"values"`` tensor is put back.

        Args:
            mask: Optional per-node mask. If given, it is split per tree using
                ``batch.lengths`` and the non-zero positions of each part are
                passed to the rule initialization; otherwise ``None`` is
                passed for every tree.
            num_samples, shuffle, mode: Forwarded to
                ``_initialize_adversarial_rules``. ``mode`` also selects how
                candidates are re-shuffled on each iteration.
        """
        tree_sizes = batch.lengths
        # Start index of every tree inside the batched (concatenated) graph.
        offsets = (np.cumsum(tree_sizes) - np.array(tree_sizes)).tolist()

        if mask is not None:
            tree_masks = torch.split(mask, tree_sizes)
            tree_nodes = [
                np.flatnonzero(tree_mask.cpu().numpy())
                for tree_mask in tree_masks
            ]
        else:
            tree_nodes = [None] * len(tree_sizes)

        g = batch.X
        assert g.number_of_nodes() == sum(tree_sizes)
        original_values = g.ndata["values"]

        tree_rules, tree_num_samples = self._initialize_adversarial_rules(
            num_samples, tree_nodes, model, batch, shuffle, mode)
        adversarial_mask = None
        batch_gradient_mode = mode in (
            AdversarialMode.BATCH_GRADIENT_ASCENT,
            AdversarialMode.BATCH_GRADIENT_BOOSTING,
        )

        # NOTE(review): unlike the sequence variant of this attack, the rules
        # are not reset when the generator finishes — confirm this is intended.
        # try/finally alone covers early close(): GeneratorExit propagates
        # through the finally block, so no explicit handler is needed.
        try:
            for idx in range(max(tree_num_samples)):
                # Work on a private copy. For a CPU tensor, .cpu().numpy()
                # shares storage with the tensor, so applying rules in place
                # would also overwrite ``original_values`` and break both the
                # boosting reset and the final restore.
                values = g.ndata["values"].cpu().numpy().copy()

                if mode == AdversarialMode.INDIVIDUAL_GRADIENT:
                    # obtain gradients w.r.t. idx-th classifiable position in each input
                    shuffle.for_next_position()
                    for rules in tree_rules:
                        for rule in rules:
                            shuffle.shuffle_candidates(rule)

                if mode == AdversarialMode.BATCH_GRADIENT_BOOSTING:
                    # mask: compute gradient only for correctly predicted positions in all previous iterations
                    g.ndata["values"] = original_values
                    adversarial_mask = model.get_adversarial_mask(
                        batch,
                        mask_field="mask_valid",
                        previous_mask=adversarial_mask)

                if batch_gradient_mode:
                    shuffle.initialize(model,
                                       batch,
                                       position_mask=adversarial_mask)
                    for rules in tree_rules:
                        for rule in rules:
                            shuffle.shuffle_candidates(rule)

                # In the batch-gradient modes the candidates have just been
                # shuffled again, so the first one (the argmax) is used.
                idx_to_use = 0 if batch_gradient_mode else idx

                for rules, offset in zip(tree_rules, offsets):
                    # one node with constant assign => one rule
                    for rule in rules:
                        # apply each rule with the batch offset of its example
                        rule.apply_first_valid(idx_to_use,
                                               values,
                                               usage_offset=offset)

                g.ndata["values"] = torch.tensor(values,
                                                 dtype=torch.long,
                                                 device=original_values.device)
                yield batch
        finally:
            g.ndata["values"] = original_values
Exemplo n.º 10
0
    def __attack_seq(
        self,
        model: NeuralModelBase,
        batch: MiniBatch,
        mask: torch.Tensor,
        num_samples,
        shuffle: ShuffleStrategy,
        mode: AdversarialMode,
    ):
        """Yield ``batch`` repeatedly with adversarially modified values.

        This is a generator. On each iteration the last input tensor
        ``batch.X[-1]`` (the values; types are left untouched) is replaced by
        a modified copy and the same ``batch`` object is yielded. When the
        generator finishes or is closed early, the original tensor is put back
        and every rule is reset.

        Args:
            mask: Must be ``None``; restricting positions is not implemented
                for sequence inputs.
            num_samples, shuffle, mode: Forwarded to
                ``_initialize_adversarial_rules``. ``mode`` also selects how
                candidates are re-shuffled on each iteration.

        Raises:
            NotImplementedError: If ``mask`` is not ``None``. Because this is
                a generator, it is raised on the first iteration, not at call
                time.
        """
        tree_sizes = batch.lengths
        if mask is not None:
            raise NotImplementedError
        tree_nodes = [None] * len(tree_sizes)

        tree_rules, tree_num_samples = self._initialize_adversarial_rules(
            num_samples, tree_nodes, model, batch, shuffle=shuffle, mode=mode)

        # the inputs should contain two tensors for types and values
        # only values are being changed
        original_values = batch.X[-1]
        adversarial_masks = None
        batch_gradient_mode = mode in (
            AdversarialMode.BATCH_GRADIENT_ASCENT,
            AdversarialMode.BATCH_GRADIENT_BOOSTING,
        )

        # try/finally alone covers early close(): GeneratorExit propagates
        # through the finally block, so no explicit handler is needed.
        try:
            for idx in range(max(tree_num_samples)):
                # Work on a private copy. For a CPU tensor, .cpu().t().numpy()
                # is a view on the tensor's storage, so applying rules in
                # place would also overwrite ``original_values`` and break
                # both the boosting reset and the final restore.
                # The transpose makes each row one example of the batch.
                batch_values = batch.X[-1].cpu().t().numpy().copy()

                if mode == AdversarialMode.INDIVIDUAL_GRADIENT:
                    # obtain gradients w.r.t. idx-th classifiable position in each input
                    shuffle.for_next_position()
                    for rules in tree_rules:
                        for rule in rules:
                            shuffle.shuffle_candidates(rule)

                if mode == AdversarialMode.BATCH_GRADIENT_BOOSTING:
                    # mask: compute gradient only for correctly predicted positions in all previous iterations
                    batch.X[-1] = original_values
                    adversarial_masks = [
                        model.get_adversarial_mask(
                            minibatch,
                            mask_field="mask_valid",
                            previous_mask=adversarial_masks,
                        ) for minibatch in batch
                    ]

                if batch_gradient_mode:
                    shuffle.initialize(model,
                                       batch,
                                       None,
                                       position_mask=adversarial_masks)
                    for rules in tree_rules:
                        for rule in rules:
                            shuffle.shuffle_candidates(rule)

                # In the batch-gradient modes the candidates have just been
                # shuffled again, so the first one (the argmax) is used.
                idx_to_use = 0 if batch_gradient_mode else idx

                assert len(batch_values) == len(tree_rules)
                # for every example in batch ...
                for rules, values in zip(tree_rules, batch_values):
                    # one node with constant assign => one rule
                    for rule in rules:
                        rule.apply_first_valid(idx_to_use, values)

                batch.X[-1] = torch.tensor(
                    batch_values.transpose(),
                    dtype=torch.long,
                    device=original_values.device,
                )
                yield batch
        finally:
            batch.X[-1] = original_values

            for rules in tree_rules:
                for rule in rules:
                    rule.reset()
Exemplo n.º 11
0
    def fit_adversarial(
        self,
        model: NeuralModelBase,
        dataset_iter: Iterable[MiniBatch],
        adversarial_iters: Union["AdversaryBatchIter",
                                 List["AdversaryBatchIter"]],
        threshold=0.5,
    ):
        """Refine ``model`` on adversarial variants of every batch.

        For each batch, the baseline correctness and predictions are computed
        once. For every adversarial variant produced by each iterator, the
        positions flagged as unsound or imprecise by
        ``AdversaryAccuracyStats.training_mask`` form the training mask for one
        ``model.fit_batch`` step on that variant, using ``model.opt`` and
        ``model.loss_function``.

        Args:
            adversarial_iters: A single adversarial iterator or a list of them.

        Returns:
            Total number of positions selected for refinement, summed over all
            adversarial batches.
        """
        if isinstance(adversarial_iters, list):
            iters = adversarial_iters
        else:
            iters = [adversarial_iters]

        refined_total = 0
        unsound_total = 0
        imprecise_total = 0

        def check_same_targets(orig_batch: MiniBatch,
                               adversarial_batch: MiniBatch):
            # The adversarial batch must keep the same number of valid
            # positions and the same ground-truth label at each of them.
            orig_valid = model.padding_mask(orig_batch,
                                            mask_field="mask_valid")
            adv_valid = model.padding_mask(adversarial_batch,
                                           mask_field="mask_valid")
            assert torch.sum(orig_valid) == torch.sum(
                adv_valid), "{} vs {}".format(torch.sum(orig_valid),
                                              torch.sum(adv_valid))
            assert torch.all(
                orig_batch.Y[orig_valid] == adversarial_batch.Y[adv_valid])

        adv_stats = AdversaryAccuracyStats(self.dataset.reject_token_id)
        for batch in tqdm.tqdm(dataset_iter, ncols=100, leave=False):
            base_correct, base_preds, base_y = self.compute_correct(
                model, batch, verbose=False, threshold=threshold)

            for adversarial_iter in iters:
                for adv_batch in adversarial_iter.iter_batch(batch):
                    check_same_targets(batch, adv_batch)

                    adv_correct, adv_preds, adv_y = self.compute_correct(
                        model, adv_batch, verbose=False, threshold=threshold)
                    assert torch.all(base_y == adv_y)

                    unsound_mask, imprecise_mask = adv_stats.training_mask(
                        base_correct, base_preds, adv_correct, adv_preds)
                    refine_mask = unsound_mask | imprecise_mask

                    refined_total += torch.sum(refine_mask).item()
                    unsound_total += torch.sum(unsound_mask).item()
                    imprecise_total += torch.sum(imprecise_mask).item()

                    model.fit_batch(
                        model.opt,
                        model.loss_function,
                        adv_batch,
                        mask_field="mask_valid",
                        training_mask=refine_mask,
                    )

        Logger.debug(
            "Number of refined samples: {}, unsound: {}, imprecise: {}".format(
                refined_total, unsound_total, imprecise_total))
        return refined_total