def compute_correct(self, model: NeuralModelBase, batch: MiniBatch, verbose=False, threshold=0.5):
    """Compare reject-aware predictions with the targets of ``batch``.

    Predictions come from ``model.predict_with_reject`` using
    ``self.dataset.reject_token_id`` and ``threshold``. A position counts as
    correct when the prediction equals ``batch.Y`` and the model's padding
    mask is set there. All three results are then flattened with
    ``masked_select`` using ``batch.masks["mask_valid"]``.

    Args:
        model: model providing ``padding_mask`` and ``predict_with_reject``.
        batch: batch providing ``Y`` and ``masks["mask_valid"]``.
        verbose: accepted for interface compatibility; not used here.
        threshold: forwarded to ``model.predict_with_reject``.

    Returns:
        Tuple ``(correct, predictions, Y)`` of flattened tensors restricted
        to the positions selected by ``mask_valid``.
    """
    padding = model.padding_mask(batch)
    valid = batch.masks["mask_valid"]
    targets = batch.Y
    predicted = model.predict_with_reject(
        batch, self.dataset.reject_token_id, threshold=threshold
    )
    hits = (predicted == targets) * padding

    if targets.dim() == 2:
        # The adversary assumes that the order of predicted labels is the
        # same in the adversarial and the original batch. All modifications
        # respect this: code may be added or removed, but the order and
        # number of predicted labels (via 'mask_valid') stays the same.
        # Because the sample length might change, masked_select has to run
        # batch-first (B x N rather than N x B); otherwise the results are
        # not guaranteed to preserve that order.
        # TODO: check that the batch is created with batch_first=False
        # (this is the default with torchtext)
        valid = valid.t()
        hits, predicted, targets = hits.t(), predicted.t(), targets.t()

    return (
        hits.masked_select(valid),
        predicted.masked_select(valid),
        targets.masked_select(valid),
    )
def get_rejection_thresholds(
    it,
    model: NeuralModelBase,
    dataset: Dataset,
    precision_thresholds: Iterable[float],
    num_bins: int = 1000,
):
    """Find, for each target precision, a cut-off on the reject probability.

    Every batch from ``it`` is run through ``model.predict_probs_with_reject``.
    The reject probabilities at the positions selected by the ``mask_valid``
    padding mask are histogrammed over ``[0, 1]`` into ``num_bins`` bins,
    once for all positions and once for the positions whose best prediction
    equals ``batch.Y``. The bins are then scanned in increasing order. Before
    bin ``i`` is added, the precision of everything in bins ``0 .. i-1`` is
    compared against each target.

    Args:
        it: iterable of batches.
        model: model providing ``predict_probs_with_reject`` and
            ``padding_mask``.
        dataset: dataset providing ``reject_token_id``.
        precision_thresholds: target precisions. Any iterable is accepted,
            including one-shot iterators.
        num_bins: number of histogram bins over ``[0, 1]``.

    Returns:
        A list with one ``SimpleNamespace(h, size)`` per target precision, in
        input order. ``h`` is the selected cut-off ``i / num_bins`` (``None``
        when the target was never reached) and ``size`` is the number of
        positions falling into the bins below that cut-off.

    Raises:
        ValueError: if ``num_bins`` is not positive.
    """
    if num_bins <= 0:
        raise ValueError("num_bins must be positive, got {}".format(num_bins))

    # Materialise once: the targets are walked once to allocate the result
    # slots and then again for every histogram bin. A one-shot iterator
    # (generator, map object, ...) would be exhausted after the first pass
    # and silently leave every threshold unset.
    precision_thresholds = list(precision_thresholds)

    num_correct = torch.zeros(num_bins)
    num_total = torch.zeros(num_bins)
    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        mask = model.padding_mask(batch, mask_field="mask_valid")
        targets = batch.Y

        best_predictions = best_predictions.masked_select(mask)
        reject_probs = reject_probs.masked_select(mask).cpu()
        targets = targets.masked_select(mask)
        is_corrects = (targets == best_predictions).cpu()

        num_total.add_(torch.histc(reject_probs, bins=num_bins, min=0, max=1))
        num_correct.add_(
            torch.histc(
                reject_probs.masked_select(is_corrects), bins=num_bins, min=0, max=1
            )
        )

    thresholds = [SimpleNamespace(h=None, size=0) for _ in precision_thresholds]

    # Running counts over the bins already passed. tolist() yields Python
    # floats, so the accumulation is done in double precision.
    seen_correct = 0.0
    seen_total = 0.0
    # NOTE(review): the check runs *before* bin i is added, so the cut-off
    # 1.0 (every position accepted) is never evaluated -- confirm intended.
    for i, (correct, total) in enumerate(
        zip(num_correct.tolist(), num_total.tolist())
    ):
        # The precision of the prefix is the same for every target, so it is
        # computed once per bin; an empty prefix counts as precision 0.
        precision = seen_correct / seen_total if seen_total else 0
        for t, precision_threshold in zip(thresholds, precision_thresholds):
            if precision_threshold <= precision:
                # update threshold if it's not set or the number of samples
                # increased (by more than 1%)
                if t.h is None or t.size * 1.01 < seen_total:
                    t.h = i / float(num_bins)
                    t.size = int(seen_total)
        seen_correct += correct
        seen_total += total

    Logger.debug(
        "Thresholds: {}, sizes: {}".format(
            [t.h for t in thresholds], [t.size for t in thresholds]
        )
    )
    return thresholds
def visualize_adversarial(self, batch: MiniBatch, model: NeuralModelBase, masks, colors):
    """Render the first sample of ``batch`` with the given masks highlighted.

    Each entry of ``masks`` is written into a zero tensor shaped like the
    ``mask_valid`` padding mask, at the positions that mask selects. The
    results are attached to the batched graph ``batch.X`` as temporary node
    data named ``"0"``, ``"1"``, ... so that ``dgl.unbatch`` splits them per
    tree, and are removed from the batched graph again afterwards.

    Only the first (tree, sample) pair is passed to
    ``TreeVisualization.visualize``; the method then blocks on ``input()``.

    Args:
        batch: batch whose ``X`` is the batched graph.
        model: must be a ``GraphModel`` (checked with ``assert``).
        masks: iterable of masks; entry ``i`` is paired with ``colors[i]``.
        colors: colours passed to ``MaskLabel.from_values``.
    """
    assert isinstance(model, GraphModel)
    samples = self.dataset.samples_for_batch(batch)
    valid_positions = model.padding_mask(batch, mask_field="mask_valid")
    graph = batch.X

    def scatter(mask):
        # Place the mask values at the valid positions, zero elsewhere.
        full = torch.zeros_like(valid_positions)
        full[valid_positions] = mask
        return full

    # Build every full-size mask before touching the graph.
    node_masks = {str(idx): scatter(mask) for idx, mask in enumerate(masks)}
    mask_names = list(node_masks)

    # Attach temporarily so that unbatching carries the masks to each tree.
    for name, full_mask in node_masks.items():
        assert name not in graph.ndata
        graph.ndata[name] = full_mask
    trees = dgl.unbatch(graph)
    for name in mask_names:
        del graph.ndata[name]

    # Only the first tree/sample pair is visualised.
    first_pair = next(zip(trees, samples), None)
    if first_pair is None:
        return
    tree, sample = first_pair

    mask_labels = [
        MaskLabel.from_values(tree.ndata[name], color)
        for name, color in zip(mask_names, colors)
    ]
    fields = self.dataset.dtrain.fields
    field_labels = [
        FieldLabel(fields["types"], tree.ndata["types"]),
        FieldLabel(fields["values"], tree.ndata["values"]),
        # TODO: topk
        # TopkLabel.from_iter([batch], sample, model, self.dataset,
        #                     mask=MaskLabel.from_sample(sample, 'mask_valid')),
    ]
    TreeVisualization.visualize(
        sample,
        ["types", "values"],
        fields,
        labels=mask_labels + field_labels,
    )
    # Wait for the user before returning.
    input()
def print_rejection_thresholds(it, model: NeuralModelBase, dataset: Dataset):
    """Log prediction accuracy at a range of reject-probability cut-offs.

    For every batch in ``it`` the best predictions and reject probabilities
    from ``model.predict_probs_with_reject`` are restricted to the positions
    selected by the ``mask_valid`` padding mask. For each cut-off ``h`` in
    ``np.arange(0.1, 1.1, 0.1)`` the function counts the positions with
    reject probability ``<= h`` and how many of those were predicted
    correctly. One debug line is logged per cut-off, followed by one line
    with the overall totals.

    Args:
        it: iterable of batches.
        model: model providing ``predict_probs_with_reject`` and
            ``padding_mask``.
        dataset: dataset providing ``reject_token_id``.
    """
    thresholds = np.arange(0.1, 1.1, 0.1)
    # Per-cut-off counters, index-aligned with ``thresholds``.
    kept_total = [0] * len(thresholds)
    kept_correct = [0] * len(thresholds)
    num_correct = 0
    num_total = 0

    for batch in tqdm.tqdm(it, ncols=100, leave=False):
        _, best_predictions, reject_probs = model.predict_probs_with_reject(
            batch, reject_id=dataset.reject_token_id
        )
        mask = model.padding_mask(batch, mask_field="mask_valid")
        targets = batch.Y.masked_select(mask)
        best_predictions = best_predictions.masked_select(mask)
        reject_probs = reject_probs.masked_select(mask)
        is_correct = targets == best_predictions

        num_correct += is_correct.sum().item()
        num_total += targets.numel()

        for idx, h in enumerate(thresholds):
            h_mask = reject_probs <= h
            kept_total[idx] += h_mask.sum().item()
            kept_correct[idx] += is_correct[h_mask].sum().item()

    for h, correct, total in zip(thresholds, kept_correct, kept_total):
        Logger.debug(
            "Threshold {:5.2f}: {:6d}/{:6d} ({:.2f}%)".format(
                h, correct, total, acc(correct, total)
            )
        )

    Logger.debug(
        "{:6d}/{:6d} ({:.2f}%)".format(
            num_correct, num_total, acc(num_correct, num_total)
        )
    )