def compute_correct(self, model: NeuralModelBase, batch: MiniBatch, verbose=False, threshold=0.5):
    """Select the valid positions of a batch and flag which were predicted correctly.

    Args:
        model: supplies ``padding_mask(batch)`` and ``predict_with_reject``.
        batch: provides the targets ``Y`` and ``masks["mask_valid"]``.
        verbose: accepted for call compatibility; not read by this method.
        threshold: forwarded to ``model.predict_with_reject``.

    Returns:
        ``(correct, predictions, Y)``, each reduced with ``masked_select`` to
        the positions enabled in ``mask_valid``.
    """
    padding = model.padding_mask(batch)
    valid = batch.masks["mask_valid"]
    targets = batch.Y
    predictions = model.predict_with_reject(
        batch, self.dataset.reject_token_id, threshold=threshold
    )
    # A position only counts as correct when the padding mask enables it too.
    hits = (predictions == targets) * padding

    if targets.dim() != 2:
        return (
            hits.masked_select(valid),
            predictions.masked_select(valid),
            targets.masked_select(valid),
        )

    # The adversary assumes that the order of predicted labels in the
    # adversarial and the original batch is the same. All the modifications
    # respect this, i.e., while code is added or removed, the order and number
    # of predicted labels (via 'mask_valid') stays the same.
    # However, since the sample length might change, masked_select has to be
    # performed batch first (B x N rather than N x B), as otherwise the
    # results are not guaranteed to preserve the order.
    # TODO: check that the batch is created with batch_first=False
    # (this is the default with torchtext)
    valid = valid.t()
    return (
        hits.t().masked_select(valid),
        predictions.t().masked_select(valid),
        targets.t().masked_select(valid),
    )
def get_rejection_thresholds(
    it,
    model: NeuralModelBase,
    dataset: Dataset,
    precision_thresholds: Iterable[float],
    num_bins: int = 1000,
):
    """Pick, for every target precision, a reject-probability threshold.

    The reject probabilities of all positions enabled by ``mask_valid`` are
    histogrammed into ``num_bins`` equal bins over ``[0, 1]``, once for all
    positions and once for the correctly predicted ones. The bins are then
    scanned in increasing order; a candidate threshold ``i / num_bins`` is
    scored by the precision of everything in the bins below it.

    Args:
        it: iterable of batches.
        model: supplies ``predict_probs_with_reject`` and ``padding_mask``.
        dataset: supplies ``reject_token_id``.
        precision_thresholds: target precisions; may be any iterable,
            including a one-shot iterator.
        num_bins: number of histogram bins (defaults to the previously
            hard-coded 1000).

    Returns:
        A list with one ``SimpleNamespace(h, size)`` per target precision, in
        input order. ``h`` is the selected threshold (``None`` if the target
        was never reached) and ``size`` the number of samples counted in the
        bins below ``h``.
    """
    # Materialise once: the targets are traversed to allocate the result
    # slots and then again for every bin, which a one-shot iterator (e.g. a
    # generator) would not survive.
    precision_thresholds = list(precision_thresholds)

    num_correct = torch.zeros(num_bins)
    num_total = torch.zeros(num_bins)
    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        mask = model.padding_mask(batch, mask_field="mask_valid")
        targets = batch.Y

        best_predictions = best_predictions.masked_select(mask)
        reject_probs = reject_probs.masked_select(mask).cpu()
        targets = targets.masked_select(mask)
        is_corrects = (targets == best_predictions).cpu()

        num_total.add_(torch.histc(reject_probs, bins=num_bins, min=0, max=1))
        num_correct.add_(
            torch.histc(
                reject_probs.masked_select(is_corrects), bins=num_bins, min=0, max=1
            )
        )

    thresholds = [SimpleNamespace(h=None, size=0) for _ in precision_thresholds]
    # Accumulate in Python floats (float64): float32 running sums stop being
    # exact integers above 2**24 samples.
    rolling_correct = 0.0
    rolling_total = 0.0
    # NOTE(review): the check runs before bin ``i`` is added, so the candidate
    # ``h = 1.0`` (all bins included) is never evaluated - confirm intended.
    for i, (correct, total) in enumerate(
        zip(num_correct.tolist(), num_total.tolist())
    ):
        # Precision of everything below the candidate threshold; it does not
        # depend on the target, so compute it once per bin.
        precision = rolling_correct / rolling_total if rolling_total else 0
        for t, precision_threshold in zip(thresholds, precision_thresholds):
            if precision_threshold <= precision:
                # update threshold if it's not set or the number of samples increased
                if t.h is None or t.size * 1.01 < rolling_total:
                    t.h = i / float(num_bins)
                    t.size = int(rolling_total)
        rolling_correct += correct
        rolling_total += total

    Logger.debug(
        "Thresholds: {}, sizes: {}".format(
            [t.h for t in thresholds], [t.size for t in thresholds]
        )
    )
    return thresholds
def load_model(model: NeuralModelBase, args, model_id):
    """Restore ``model``'s parameters from its checkpoint file, if present.

    The checkpoint path is ``checkpoint_dir(args)`` joined with
    ``checkpoint_name(args, model_id)``.

    Args:
        model: model whose ``load_state_dict`` receives the stored data.
        args: passed through to ``checkpoint_dir`` / ``checkpoint_name``.
        model_id: passed through to ``checkpoint_name``.

    Returns:
        ``True`` if the checkpoint existed and was loaded, ``False`` if the
        file does not exist.
    """
    import torch

    checkpoint_file = os.path.join(checkpoint_dir(args), checkpoint_name(args, model_id))
    print("checkpoint_file", checkpoint_file)
    if not os.path.exists(checkpoint_file):
        return False

    Logger.debug("Loading model from {}".format(checkpoint_file))
    # Load onto the CPU first: without ``map_location`` the tensors are
    # restored onto the device they were saved from, which fails when that
    # device is not available here. ``load_state_dict`` copies the values into
    # the model's own parameters, wherever those live.
    # NOTE(review): ``torch.load`` unpickles the file - only load checkpoints
    # from a trusted source.
    data = torch.load(checkpoint_file, map_location="cpu")
    model.load_state_dict(data)
    return True
def save_model(model: NeuralModelBase, args, model_id):
    """Write ``model.state_dict()`` to the checkpoint file for ``model_id``.

    The target path is ``checkpoint_dir(args)`` joined with
    ``checkpoint_name(args, model_id)``.
    """
    import torch

    directory = checkpoint_dir(args)
    file_name = checkpoint_name(args, model_id)
    target_path = os.path.join(directory, file_name)

    Logger.debug("Saving model to {}".format(target_path))
    state = model.state_dict()
    torch.save(state, target_path)
def train_base_model(
    model: NeuralModelBase,
    dataset: Dataset,
    num_epochs,
    train_iter,
    valid_iter,
    lr=0.001,
    verbose=True,
):
    """Train ``model`` with Adam and an unreduced cross-entropy loss.

    The optimizer and the loss are also stored on the model as ``model.opt``
    and ``model.loss_function``.

    Args:
        model: model to train; must provide ``fit`` and ``accuracy``.
        dataset: supplies ``TARGET`` for the accuracy computation.
        num_epochs: number of passes over ``train_iter``.
        train_iter: training batches.
        valid_iter: one validation iterator or a list of them.
        lr: Adam learning rate.
        verbose: forwarded to ``model.accuracy`` for the validation runs.

    Returns:
        ``(train_prec, valid_prec)``: the ``"mask_valid_noreject_acc"`` entry
        of the final training statistics and of the last validation run
        (``None`` if no validation run happened).
    """
    if isinstance(valid_iter, list):
        validation_sets = valid_iter
    else:
        validation_sets = [valid_iter]

    Logger.start_scope("Training Model")
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.opt = optimizer
    criterion = nn.CrossEntropyLoss(reduction="none")
    model.loss_function = criterion

    metric = "mask_valid_noreject_acc"
    train_prec = valid_prec = None
    for epoch in range(num_epochs):
        Logger.start_scope("Epoch {}".format(epoch))
        model.fit(train_iter, optimizer, criterion, mask_field="mask_valid")
        for validation_set in validation_sets:
            valid_prec = model.accuracy(
                validation_set, dataset.TARGET, verbose=verbose
            )[metric]
            Logger.debug(f"valid_prec: {valid_prec}")
        Logger.end_scope()

    train_prec = model.accuracy(train_iter, dataset.TARGET, verbose=False)[metric]
    Logger.debug(f"train_prec: {train_prec}, valid_prec: {valid_prec}")
    Logger.end_scope()
    return train_prec, valid_prec
def print_rejection_thresholds(it, model: NeuralModelBase, dataset: Dataset):
    """Log the accuracy over the samples kept at a range of reject thresholds.

    For every threshold in ``np.arange(0.1, 1.1, 0.1)`` the valid positions
    whose reject probability is at most that threshold are counted, together
    with how many of them were predicted correctly. One line is logged per
    threshold, followed by the totals over all valid positions.

    Args:
        it: iterable of batches.
        model: supplies ``predict_probs_with_reject`` and ``padding_mask``.
        dataset: supplies ``reject_token_id``.
    """
    thresholds = np.arange(0.1, 1.1, 0.1)
    # One counter pair per threshold, kept in the same order as ``thresholds``.
    per_threshold = [SimpleNamespace(correct=0, total=0) for _ in thresholds]
    overall = SimpleNamespace(correct=0, total=0)

    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        valid = model.padding_mask(batch, mask_field="mask_valid")
        targets = batch.Y.masked_select(valid)
        best_predictions = best_predictions.masked_select(valid)
        reject_probs = reject_probs.masked_select(valid)

        hits = targets == best_predictions
        overall.correct += hits.sum().item()
        overall.total += targets.numel()

        for h, counters in zip(thresholds, per_threshold):
            kept = reject_probs <= h
            counters.total += kept.sum().item()
            counters.correct += hits.masked_select(kept).sum().item()

    for h, counters in zip(thresholds, per_threshold):
        Logger.debug(
            "Threshold {:5.2f}: {:6d}/{:6d} ({:.2f}%)".format(
                h,
                counters.correct,
                counters.total,
                acc(counters.correct, counters.total),
            )
        )
    Logger.debug(
        "{:6d}/{:6d} ({:.2f}%)".format(
            overall.correct, overall.total, acc(overall.correct, overall.total)
        )
    )
def visualize_adversarial(self, batch: MiniBatch, model: NeuralModelBase, masks, colors):
    """Show the first tree of ``batch`` with the given masks drawn as labels.

    Each mask is expanded to the shape of the model's ``mask_valid`` padding
    mask, attached to the batched graph as temporary node data so that
    ``dgl.unbatch`` splits it per tree, and removed from the graph again.
    Only the first (tree, sample) pair is visualised; the method then blocks
    on ``input()``.

    Args:
        batch: batch whose ``X`` is the batched graph.
        model: must be a ``GraphModel``.
        masks: masks covering the positions enabled by the padding mask.
        colors: one colour per mask, paired with ``masks`` by position.
    """
    assert isinstance(model, GraphModel)
    samples = self.dataset.samples_for_batch(batch)
    training_mask = model.padding_mask(batch, mask_field="mask_valid")
    graph = batch.X

    # Scatter every partial mask into a full-size tensor, keyed by its index.
    mask_data = {}
    for position, partial_mask in enumerate(masks):
        full_mask = torch.zeros_like(training_mask)
        full_mask[training_mask] = partial_mask
        mask_data[str(position)] = full_mask
    mask_names = list(mask_data)

    # Temporarily store the masks on the graph so unbatching splits them too.
    for name, full_mask in mask_data.items():
        assert name not in graph.ndata
        graph.ndata[name] = full_mask
    trees = dgl.unbatch(graph)
    for name in mask_names:
        del graph.ndata[name]

    first_pair = next(zip(trees, samples), None)
    if first_pair is None:
        return
    tree, sample = first_pair

    fields = self.dataset.dtrain.fields
    labels = [
        MaskLabel.from_values(tree.ndata[name], color)
        for name, color in zip(mask_names, colors)
    ]
    labels = labels + [
        FieldLabel(fields["types"], tree.ndata["types"]),
        FieldLabel(fields["values"], tree.ndata["values"]),
        # TODO: topk
        # TopkLabel.from_iter([batch], sample, model, self.dataset,
        #                     mask=MaskLabel.from_sample(sample, 'mask_valid')),
    ]
    TreeVisualization.visualize(
        sample,
        ["types", "values"],
        fields,
        labels=labels,
    )
    # Wait for the user before returning.
    input()
def _o_upper_schedule(o_base, num_epochs, target_o):
    """Return the per-epoch ``o`` values that ``train_model`` assigns to its loss."""
    half = num_epochs // 2
    if half <= 0:
        # Too few epochs for a ramp (previously a ZeroDivisionError): use the
        # target value for every requested epoch.
        return [target_o] * num_epochs

    # ``half`` evenly spaced fractions 1, 1 - 1/half, ..., 1/half. linspace
    # fixes the count; ``np.arange`` with a fractional step may produce one
    # element more or less because of rounding.
    fractions = np.linspace(1.0, 0.0, num=half, endpoint=False)
    schedule = [f * o_base + (1 - f) * 1.0 for f in fractions]
    # The second ramp starts half-way between 1.0 and the end of the first.
    midpoint = (1.0 + schedule[-1]) / 2
    schedule += [f * midpoint + (1 - f) * target_o for f in fractions]
    schedule += [target_o] * half
    return schedule


def train_model(
    model: NeuralModelBase,
    dataset: Dataset,
    num_epochs,
    train_iter,
    valid_iter,
    lr=0.001,
    weight=None,
    target_o=1.0,
):
    """Train ``model`` with ``RejectionCrossEntropyLoss`` and an annealed ``o``.

    The loss attribute ``o`` is set before every epoch from a schedule with
    three phases of ``num_epochs // 2`` epochs each: a ramp starting at
    ``o_base`` towards 1.0, a ramp from the midpoint of 1.0 and the last
    first-phase value towards ``target_o``, and a constant ``target_o`` phase.
    The number of epochs actually run is therefore ``3 * (num_epochs // 2)``.
    For ``num_epochs < 2`` every requested epoch uses ``target_o``.

    The optimizer and the loss are also stored on the model as ``model.opt``
    and ``model.loss_function``.

    Args:
        model: model to train; must provide ``fit`` and ``accuracy``.
        dataset: supplies ``TARGET`` (with its ``vocab``) and
            ``reject_token_id``.
        num_epochs: controls the schedule length as described above.
        train_iter: training batches.
        valid_iter: validation batches, evaluated after every epoch.
        lr: Adam learning rate.
        weight: forwarded to ``RejectionCrossEntropyLoss``.
        target_o: final value of the schedule.

    Returns:
        ``(train_prec, valid_prec)``: the ``"mask_valid_noreject_acc"`` entry
        of the final training statistics and of the last validation run
        (``None`` if no epoch was run).
    """
    opt = optim.Adam(model.parameters(), lr=lr)

    Logger.start_scope("Training Model")
    vocab_size = len(dataset.TARGET.vocab)
    # NOTE(review): the original comment names three special tokens
    # ('reject', '<unk>', '<pad>') while four are subtracted - confirm.
    o_base = vocab_size - 4
    loss_function = RejectionCrossEntropyLoss(
        o_base,
        vocab_size,
        dataset.reject_token_id,
        reduction="none",
        weight=weight,
    )
    model.loss_function = loss_function
    model.opt = opt

    schedule = _o_upper_schedule(o_base, num_epochs, target_o)

    train_prec, valid_prec = None, None
    for epoch, o_upper in enumerate(schedule):
        Logger.start_scope("Epoch {}, o_upper={:.3f}".format(epoch, o_upper))
        loss_function.o = o_upper
        model.fit(train_iter, opt, loss_function, mask_field="mask_valid")

        valid_stats = model.accuracy(valid_iter, dataset.TARGET)
        valid_prec = valid_stats["mask_valid_noreject_acc"]
        Logger.debug(f"valid_prec: {valid_prec}")
        Logger.end_scope()

    train_stats = model.accuracy(train_iter, dataset.TARGET, verbose=False)
    train_prec = train_stats["mask_valid_noreject_acc"]
    Logger.debug(f"train_prec: {train_prec}, valid_prec: {valid_prec}")
    Logger.end_scope()
    return train_prec, valid_prec
def __attack_graph(
    self,
    model: NeuralModelBase,
    batch: MiniBatch,
    mask: torch.Tensor,
    num_samples,
    shuffle: ShuffleStrategy,
    mode: AdversarialMode,
):
    """Yield ``batch`` repeatedly with adversarially modified node values.

    This is a generator. On every step the rules obtained from
    ``_initialize_adversarial_rules`` are applied to a working copy of the
    graph's ``"values"`` node data, the result is written back to
    ``batch.X.ndata["values"]`` and the same ``batch`` object is yielded.
    When the generator finishes, is closed or fails, the original
    ``"values"`` tensor is put back.

    Args:
        model: model under attack.
        batch: batch whose ``X`` is the batched graph and whose ``lengths``
            holds the number of nodes of every tree.
        mask: optional per-node mask over the whole batch; when given it is
            split per tree and the non-zero positions are passed on as the
            candidate nodes. ``None`` passes ``None`` for every tree.
        num_samples: forwarded to ``_initialize_adversarial_rules``.
        shuffle: strategy that orders the candidates of every rule.
        mode: selects how candidates are (re)ordered on every step.

    Yields:
        ``batch``, with modified ``"values"`` node data.
    """
    tree_sizes = batch.lengths
    # Index of the first node of every tree inside the batched graph.
    offsets = (np.cumsum(tree_sizes) - np.array(tree_sizes)).tolist()
    if mask is not None:
        tree_masks = torch.split(mask, tree_sizes)
        tree_nodes = [
            np.flatnonzero(tree_mask.cpu().numpy()) for tree_mask in tree_masks
        ]
    else:
        tree_nodes = [None] * len(tree_sizes)

    g = batch.X
    assert g.number_of_nodes() == sum(tree_sizes)
    original_values = g.ndata["values"]

    tree_rules, tree_num_samples = self._initialize_adversarial_rules(
        num_samples, tree_nodes, model, batch, shuffle, mode
    )

    # The mode never changes while iterating, so decide this once.
    uses_batch_gradient = mode in (
        AdversarialMode.BATCH_GRADIENT_ASCENT,
        AdversarialMode.BATCH_GRADIENT_BOOSTING,
    )

    adversarial_mask = None
    try:
        # ``default=0``: an empty batch yields nothing instead of raising.
        for idx in range(max(tree_num_samples, default=0)):
            # Work on a private copy. For CPU tensors ``.cpu().numpy()``
            # shares storage with the tensor, so writing into the array would
            # also overwrite ``original_values`` and defeat the restore in
            # ``finally``.
            values = g.ndata["values"].cpu().numpy().copy()

            if mode == AdversarialMode.INDIVIDUAL_GRADIENT:
                # obtain gradients w.r.t. idx-th classifiable position in each input
                shuffle.for_next_position()
                for rules in tree_rules:
                    for rule in rules:
                        shuffle.shuffle_candidates(rule)

            if mode == AdversarialMode.BATCH_GRADIENT_BOOSTING:
                # mask: compute gradient only for correctly predicted
                # positions in all previous iterations
                g.ndata["values"] = original_values
                adversarial_mask = model.get_adversarial_mask(
                    batch, mask_field="mask_valid", previous_mask=adversarial_mask
                )

            if uses_batch_gradient:
                shuffle.initialize(model, batch, position_mask=adversarial_mask)
                for rules in tree_rules:
                    for rule in rules:
                        shuffle.shuffle_candidates(rule)

            # values have been shuffled again in the batch-gradient modes, so
            # we can just use the argmax (index 0) there
            idx_to_use = 0 if uses_batch_gradient else idx
            for rules, offset in zip(tree_rules, offsets):
                # one node with constant assign => one rule
                for rule in rules:
                    # apply each rule with the batch offset of its example;
                    # the rule is assumed to write into ``values`` in place
                    # (its return value is not used)
                    rule.apply_first_valid(idx_to_use, values, usage_offset=offset)

            g.ndata["values"] = torch.tensor(
                values, dtype=torch.long, device=original_values.device
            )
            yield batch
    except GeneratorExit:
        return
    finally:
        g.ndata["values"] = original_values
def __attack_seq(
    self,
    model: NeuralModelBase,
    batch: MiniBatch,
    mask: torch.Tensor,
    num_samples,
    shuffle: ShuffleStrategy,
    mode: AdversarialMode,
):
    """Yield ``batch`` repeatedly with adversarially modified sequence values.

    This is a generator. Only the last entry of ``batch.X`` (the values) is
    changed: on every step the rules obtained from
    ``_initialize_adversarial_rules`` are applied to a transposed working copy
    with one row per example, the result is transposed back, written to
    ``batch.X[-1]`` and the same ``batch`` object is yielded. When the
    generator finishes, is closed or fails, the original tensor is put back
    and every rule is reset.

    Args:
        model: model under attack.
        batch: batch whose ``X[-1]`` holds the values and whose ``lengths``
            has one entry per example.
        mask: must be ``None``; position masks are not implemented here.
        num_samples: forwarded to ``_initialize_adversarial_rules``.
        shuffle: strategy that orders the candidates of every rule.
        mode: selects how candidates are (re)ordered on every step.

    Yields:
        ``batch``, with a modified ``X[-1]``.

    Raises:
        NotImplementedError: if ``mask`` is not ``None`` (raised when the
            generator is first advanced).
    """
    if mask is not None:
        raise NotImplementedError
    tree_sizes = batch.lengths
    tree_nodes = [None] * len(tree_sizes)

    tree_rules, tree_num_samples = self._initialize_adversarial_rules(
        num_samples, tree_nodes, model, batch, shuffle=shuffle, mode=mode
    )

    # the inputs should contain two tensors for types and values
    # only values are being changed
    original_values = batch.X[-1]
    uses_batch_gradient = mode in (
        AdversarialMode.BATCH_GRADIENT_ASCENT,
        AdversarialMode.BATCH_GRADIENT_BOOSTING,
    )

    adversarial_masks = None
    try:
        # ``default=0``: an empty batch yields nothing instead of raising.
        for idx in range(max(tree_num_samples, default=0)):
            # Work on a private copy. For CPU tensors ``.cpu().t().numpy()``
            # is a view on the tensor's storage, so writing into it would also
            # overwrite ``original_values`` and defeat the restore in
            # ``finally``.
            batch_values = batch.X[-1].cpu().t().numpy().copy()
            idx_to_use = idx

            if mode == AdversarialMode.INDIVIDUAL_GRADIENT:
                # obtain gradients w.r.t. idx-th classifiable position in each input
                shuffle.for_next_position()
                for rules in tree_rules:
                    for rule in rules:
                        shuffle.shuffle_candidates(rule)

            if mode == AdversarialMode.BATCH_GRADIENT_BOOSTING:
                # mask: compute gradient only for correctly predicted
                # positions in all previous iterations
                batch.X[-1] = original_values
                # NOTE(review): this iterates over ``batch`` itself and hands
                # the whole previous list to every call - kept as is, please
                # confirm against ``get_adversarial_mask``.
                adversarial_masks = [
                    model.get_adversarial_mask(
                        minibatch,
                        mask_field="mask_valid",
                        previous_mask=adversarial_masks,
                    )
                    for minibatch in batch
                ]

            if uses_batch_gradient:
                shuffle.initialize(model, batch, None, position_mask=adversarial_masks)
                for rules in tree_rules:
                    for rule in rules:
                        shuffle.shuffle_candidates(rule)
                # values have been shuffled again, we can just use the argmax
                idx_to_use = 0

            assert len(batch_values) == len(tree_rules)
            # for every example in batch ...
            for rules, values in zip(tree_rules, batch_values):
                # one node with constant assign => one rule; ``values`` is a
                # row view, so the rule is assumed to write into
                # ``batch_values`` in place (its return value is not used)
                for rule in rules:
                    rule.apply_first_valid(idx_to_use, values)

            batch.X[-1] = torch.tensor(
                batch_values.transpose(),
                dtype=torch.long,
                device=original_values.device,
            )
            yield batch
    except GeneratorExit:
        return
    finally:
        batch.X[-1] = original_values
        for rules in tree_rules:
            for rule in rules:
                rule.reset()
def fit_adversarial(
    self,
    model: NeuralModelBase,
    dataset_iter: Iterable[MiniBatch],
    adversarial_iters: Union["AdversaryBatchIter", List["AdversaryBatchIter"]],
    threshold=0.5,
):
    """Fit ``model`` on the adversarial variants of every batch.

    For every batch the predictions on the unmodified input are computed
    first. Each adversarial variant produced by the given iterators is then
    compared against them; ``AdversaryAccuracyStats.training_mask`` returns an
    unsound and an imprecise mask, and ``model.fit_batch`` is called on the
    adversarial batch with the union of both as training mask.

    Args:
        model: model to train; uses ``model.opt`` and ``model.loss_function``.
        dataset_iter: the original batches.
        adversarial_iters: one adversarial batch iterator or a list of them;
            each must provide ``iter_batch(batch)``.
        threshold: forwarded to ``compute_correct``.

    Returns:
        The total number of positions selected for training over all
        adversarial batches.
    """
    if isinstance(adversarial_iters, list):
        attack_iters = adversarial_iters
    else:
        attack_iters = [adversarial_iters]

    def sanity_check(orig_batch: MiniBatch, adversarial_batch: MiniBatch):
        # Both batches must expose the same targets, in the same order, at
        # their valid positions.
        orig_valid = model.padding_mask(orig_batch, mask_field="mask_valid")
        adv_valid = model.padding_mask(adversarial_batch, mask_field="mask_valid")
        assert torch.sum(orig_valid) == torch.sum(adv_valid), "{} vs {}".format(
            torch.sum(orig_valid), torch.sum(adv_valid)
        )
        assert torch.all(orig_batch.Y[orig_valid] == adversarial_batch.Y[adv_valid])

    num_refined = num_unsound = num_imprecise = 0
    adv_stats = AdversaryAccuracyStats(self.dataset.reject_token_id)

    for batch in tqdm.tqdm(dataset_iter, ncols=100, leave=False):
        base_correct, base_preds, base_y = self.compute_correct(
            model, batch, verbose=False, threshold=threshold
        )
        for attack_iter in attack_iters:
            for adv_batch in attack_iter.iter_batch(batch):
                sanity_check(batch, adv_batch)

                adv_correct, adv_preds, adv_y = self.compute_correct(
                    model, adv_batch, verbose=False, threshold=threshold
                )
                assert torch.all(base_y == adv_y)

                unsound_mask, imprecise_mask = adv_stats.training_mask(
                    base_correct, base_preds, adv_correct, adv_preds
                )
                refine_mask = unsound_mask | imprecise_mask

                num_refined += refine_mask.sum().item()
                num_unsound += unsound_mask.sum().item()
                num_imprecise += imprecise_mask.sum().item()

                model.fit_batch(
                    model.opt,
                    model.loss_function,
                    adv_batch,
                    mask_field="mask_valid",
                    training_mask=refine_mask,
                )

    Logger.debug(
        "Number of refined samples: {}, unsound: {}, imprecise: {}".format(
            num_refined, num_unsound, num_imprecise
        )
    )
    return num_refined