def disbert_custom_forward(args, model, batch, train_numbers, do_eval, strategy=None, split=None):
    input_ids, attention_mask, input_values, values_bool, output_values, output_mask = batch
    batch_size, _ = input_ids.size()
    device = input_ids.device

    # Replace the masked numbers with anomalies drawn from the training distribution.
    input_anom_values, output_fake_labels, _ = anomaly_sample(
        input_values, output_values, output_mask, train_numbers, 'random', True)
    output_true_labels = torch.ones_like(output_fake_labels)
    del input_values
    input_true_values = output_values

    if args.embed_digit:
        input_true_digits = values_to_string(input_true_values)
        input_anom_digits = values_to_string(input_anom_values)
    else:
        input_anom_digits = None
        input_true_digits = None

    # One discriminator pass over the corrupted numbers, one over the true ones.
    fake_loss = model(input_ids, input_anom_values, values_bool, attention_mask,
                      input_digits=input_anom_digits, output_values=None,
                      output_mask=output_mask, output_labels=output_fake_labels)
    true_loss = model(input_ids, input_true_values, values_bool, attention_mask,
                      input_digits=input_true_digits, output_values=None,
                      output_mask=output_mask, output_labels=output_true_labels)

    if do_eval and split == 'test':
        if strategy == 'one':
            # Score every candidate exponent for the masked position and pick the one
            # the discriminator finds most plausible. This assumes
            # args.n_exponent == args.max_exponent - args.min_exponent.
            masked_values = torch.masked_select(output_values, output_mask.bool())
            true_exp_ids = fexp(masked_values)
            masked_ind = torch.where(output_mask == 1)
            all_scores = torch.zeros((batch_size, args.n_exponent), device=device)
            ind = 0
            for i in range(args.min_exponent, args.max_exponent):
                input_anom_values = output_values.clone()
                input_anom_values[masked_ind] = 10.0 ** i
                if args.embed_digit:
                    input_anom_digits = values_to_string(input_anom_values)
                else:
                    input_anom_digits = None
                _, outputs = model(input_ids, input_anom_values, values_bool, attention_mask,
                                   input_digits=input_anom_digits, output_values=None,
                                   output_mask=output_mask, output_labels=output_true_labels,
                                   do_eval=True)
                all_scores[:, ind] = outputs['log_likelihood'].sum(dim=1)
                ind += 1
            pred_exp_ids = torch.argmax(all_scores, dim=1)  # index in embedding space
            pred_exp_ids += -1  # shift back to number-space exponents
            # Strategy 'one' masks a single number per example, so the shapes line up.
            exp_acc = torch.sum(pred_exp_ids == true_exp_ids)
            return fake_loss, true_loss, exp_acc
    return fake_loss, true_loss
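# A minimal sketch of how the forward above might be driven during training.
# The optimizer handling and the equal weighting of the two loss terms are
# illustrative assumptions, not taken from this file.
def disbert_train_step_sketch(args, model, optimizer, batch, train_numbers):
    fake_loss, true_loss = disbert_custom_forward(args, model, batch, train_numbers, do_eval=False)
    loss = fake_loss + true_loss  # assumed: unweighted sum of the two terms
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()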
def predict(self, mean_prediction_k, logvar_prediction, exponent_prediction):
    b, s, k = mean_prediction_k.size()
    # Pick the most likely exponent per position, then gather the matching
    # component mean and rescale it into number space.
    exp_ind = torch.argmax(exponent_prediction, dim=2)
    f_e = torch.take(self.f_e, exp_ind)
    mean_prediction = torch.gather(mean_prediction_k, 2, exp_ind.unsqueeze(dim=2)).squeeze(dim=2)
    pred_values = mean_prediction * f_e
    pred_exponent = fexp(pred_values, ignore=True)
    pred_mantissa = fman(pred_values, ignore=True)
    return pred_mantissa, pred_exponent
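# fexp / fman are used throughout this file but defined elsewhere. A minimal
# sketch consistent with their usage, assuming a base-10 decomposition
# x = mantissa * 10**exponent; the `ignore` flag is assumed to guard padded
# or zero positions. The *_sketch names are hypothetical.
def fexp_sketch(values, ignore=False):
    vals = values.abs()
    if ignore:
        vals = vals.clamp(min=1e-8)  # avoid log10(0) on padded positions
    return torch.floor(torch.log10(vals)).long()

def fman_sketch(values, ignore=False):
    exps = fexp_sketch(values, ignore=ignore).float()
    return values / 10.0 ** exps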
def oracle_predict(self, output_values):
    '''Pick the component whose mean is closest to x.'''
    b, s = output_values.size()
    means_expanded = self.means.repeat(b, s, 1)
    output_values_rpt = output_values.repeat(self.n_components, 1, 1).permute(1, 2, 0)
    diff = torch.abs(means_expanded - output_values_rpt)
    oracle_pi = torch.argmin(diff, dim=2)
    oracle_predictions = self.means[oracle_pi]
    pred_exponent = fexp(oracle_predictions, ignore=True)
    pred_mantissa = fman(oracle_predictions, ignore=True)
    return pred_mantissa, pred_exponent
def numeracy_metrics(output_values, output_mask, split, histograms):
    key = f"{split}_values"
    if histograms.get(key) is None:
        histograms[f"{split}_values"] = []
        histograms[f"{split}_exponents"] = []
        histograms[f"{split}_mantissas"] = []
    # Keep only the masked (i.e. evaluated) numbers.
    mask_index = (output_mask == 1).view(-1).nonzero().squeeze()
    output_values = output_values.view(-1)
    output_values = output_values[mask_index].view(-1).cpu()
    exponents = fexp(output_values)
    mantissas = fman(output_values)
    histograms[f"{split}_values"].extend(output_values.numpy())
    histograms[f"{split}_exponents"].extend(exponents.numpy())
    histograms[f"{split}_mantissas"].extend(mantissas.numpy())
    return histograms
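# A hedged sketch of one way the `histograms` dict collected above could be
# consumed after an evaluation run. matplotlib, the bin count, and the output
# path are illustrative assumptions, not part of the original pipeline.
def dump_histograms_sketch(histograms, split, out_dir='.'):
    import matplotlib.pyplot as plt
    for suffix in ('values', 'exponents', 'mantissas'):
        data = histograms.get(f"{split}_{suffix}", [])
        if len(data) == 0:
            continue
        plt.figure()
        plt.hist(data, bins=50)
        plt.title(f"{split} {suffix}")
        plt.savefig(f"{out_dir}/{split}_{suffix}.png")
        plt.close()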
def get_numbers_from_split(args, tokenizer, device, num_data_epochs, split):
    counter_nums = Counter()
    epoch = 0
    epoch_dataset = NumericalPregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data,
                                                 tokenizer=tokenizer, num_data_epochs=num_data_epochs,
                                                 reduce_memory=args.reduce_memory, split=split)
    train_sampler = RandomSampler(epoch_dataset)
    train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    all_nums = []
    with tqdm(total=len(train_dataloader), desc=f"Getting {split} #s") as pbar:
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, attention_mask, input_values, values_bool, output_values, output_mask = batch
            batch_size, _ = input_ids.size()
            assert torch.all(output_mask.sum(dim=1) > 0)
            for i in range(batch_size):
                # as_tuple=False avoids the deprecated bare nonzero() call
                mask_index = torch.nonzero(output_mask[i] == 1, as_tuple=False).squeeze().cpu().view(-1)
                values = output_values[i][mask_index].tolist()
                all_nums.extend(values)
                counter_nums[len(values)] += 1
            pbar.update(1)
    print('Distribution of numbers per datum', counter_nums)
    all_nums = np.array(all_nums)
    print(f'min: {np.min(all_nums)}, max: {np.max(all_nums)}, mean: {np.mean(all_nums)}')
    print(f'percentile, 25: {np.percentile(all_nums, 25)}, '
          f'50: {np.percentile(all_nums, 50)}, 75: {np.percentile(all_nums, 75)}')
    exp_counter = Counter()
    tensor_nums = torch.tensor(all_nums, dtype=torch.float)
    all_exps = fexp(tensor_nums)
    exp_counter.update(all_exps.numpy())
    print('exp_counter', exp_counter)
    return all_nums
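# Sketch of glue code deriving the statistics that evaluation() below expects
# from get_numbers_from_split. This wrapper is an illustrative assumption;
# the repository's real call sites may differ.
def compute_train_stats_sketch(args, tokenizer, device, num_data_epochs):
    train_numbers = get_numbers_from_split(args, tokenizer, device, num_data_epochs, split='train')
    train_mean = float(np.mean(train_numbers))
    train_median = float(np.median(train_numbers))
    return train_numbers, train_mean, train_median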
def predict(self, mu_pred, **kwargs):
    x_pred = self.f_forward(mu_pred, **kwargs)
    pred_exponent = fexp(x_pred, ignore=True)
    pred_mantissa = fman(x_pred, ignore=True)
    return pred_mantissa, pred_exponent
def predict(self, logits):
    ind = torch.argmax(logits, dim=2)
    pred_values = self.means[ind]
    pred_exponent = fexp(pred_values, ignore=True)
    pred_mantissa = fman(pred_values, ignore=True)
    return pred_mantissa, pred_exponent
def evaluation(args, model, tokenizer, device, global_step, split='valid',
               train_mean=None, train_median=None, train_numbers=None):
    all_metrics = {}
    num_data_epochs = args.epochs
    epoch = 0
    if split == 'train':
        strategies = ['']
    else:
        strategies = ['one', 'all']
    for strategy in strategies:
        epoch_dataset = NumericalPregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data,
                                                     tokenizer=tokenizer, num_data_epochs=num_data_epochs,
                                                     reduce_memory=args.reduce_memory, split=split,
                                                     strategy=strategy)
        train_sampler = SequentialSampler(epoch_dataset)
        eval_metrics = {}
        histograms = {}
        with torch.set_grad_enabled(False):
            if args.do_anomaly and strategy == 'one':
                train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler,
                                              batch_size=args.eval_batch_size)
                for option in ['random', 'string']:
                    print('anomaly', option)
                    eval_metrics = anomaly_evaluation(args, model, device, tokenizer, train_dataloader,
                                                      eval_metrics, '', train_numbers, option)
        train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.eval_batch_size)
        nb_eval_examples = 0.0
        with torch.set_grad_enabled(False):
            with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
                for step, batch in enumerate(train_dataloader):
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, attention_mask, input_values, values_bool, output_values, output_mask = batch
                    if args.embed_digit:
                        input_digits = values_to_string(input_values)
                    else:
                        input_digits = None
                    batch_size = input_ids.size(0)
                    histograms = numeracy_metrics(output_values, output_mask, split, histograms)
                    if split in ('valid', 'train', 'test'):
                        torch.cuda.empty_cache()
                        loss, outputs = model(input_ids, input_values, values_bool, attention_mask,
                                              input_digits=input_digits, output_values=output_values,
                                              output_mask=output_mask, do_eval=True)
                        pred_mantissa, pred_exponent = outputs['pred_mantissa'], outputs['pred_exponent']
                        true_mantissa, true_exponent = fman(output_values), fexp(output_values)
                        eval_metrics = loss_metrics(loss, eval_metrics)
                        if args.do_log:
                            metric_value = outputs['flow_mu_pred']
                            eval_metrics = log_metrics(eval_metrics, metric_value, output_mask, mode='')
                        if args.do_flow:
                            flow_items = {k: v for (k, v) in outputs.items() if k.startswith('flow')}
                            eval_metrics = flow_metrics(eval_metrics, flow_items, output_mask,
                                                        args.flow_v, mode='')
                        if args.do_gmm:
                            # Oracle baseline: the GMM component nearest to each true value.
                            oracle_mantissa, oracle_exponent = model.oracle_predict(output_values)
                            eval_metrics = mantissa_metrics(true_mantissa, oracle_mantissa, output_mask,
                                                            eval_metrics, 'oracle_')
                            eval_metrics, histograms = exponent_metrics(true_exponent, oracle_exponent,
                                                                        output_mask, eval_metrics,
                                                                        histograms, 'oracle_')
                            eval_metrics, _, _ = regression_metrics(true_mantissa, oracle_mantissa,
                                                                    true_exponent, oracle_exponent,
                                                                    output_mask, output_values,
                                                                    eval_metrics, 'oracle_')
                        if train_mean is not None:
                            # Constant baselines: always predict the training mean / median.
                            train_means = torch.zeros_like(true_mantissa) + train_mean
                            train_mean_exponents = fexp(train_means)
                            train_mean_mantissas = fman(train_means)
                            eval_metrics, histograms = exponent_metrics(true_exponent, train_mean_exponents,
                                                                        output_mask, eval_metrics,
                                                                        histograms, 'mean_')
                            eval_metrics, _, _ = regression_metrics(true_mantissa, train_mean_mantissas,
                                                                    true_exponent, train_mean_exponents,
                                                                    output_mask, output_values,
                                                                    eval_metrics, 'mean_')
                            train_medians = torch.zeros_like(true_mantissa) + train_median
                            train_median_exponents = fexp(train_medians)
                            train_median_mantissas = fman(train_medians)
                            eval_metrics, histograms = exponent_metrics(true_exponent, train_median_exponents,
                                                                        output_mask, eval_metrics,
                                                                        histograms, 'median_')
                            eval_metrics, _, _ = regression_metrics(true_mantissa, train_median_mantissas,
                                                                    true_exponent, train_median_exponents,
                                                                    output_mask, output_values,
                                                                    eval_metrics, 'median_')
                        eval_metrics = mantissa_metrics(true_mantissa, pred_mantissa, output_mask, eval_metrics)
                        eval_metrics, histograms = exponent_metrics(true_exponent, pred_exponent, output_mask,
                                                                    eval_metrics, histograms)
                        eval_metrics, true_numbers, pred_numbers = regression_metrics(true_mantissa, pred_mantissa,
                                                                                      true_exponent, pred_exponent,
                                                                                      output_mask, output_values,
                                                                                      eval_metrics)
                        nb_eval_examples += torch.sum(output_mask).float().item()
                    pbar.update(1)
        if split in ('valid', 'train', 'test'):
            prefix = f'{split}_{strategy}' if strategy != '' else f'{split}'
            summary = summarize_metrics(eval_metrics, nb_eval_examples, prefix)
            all_metrics.update(summary)
            log_wandb(summary, global_step)
    return all_metrics
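# Illustrative sketch of wiring the pieces together at evaluation time, using
# the hypothetical stats helper sketched earlier; this is not the repository's
# actual entry point.
def run_validation_sketch(args, model, tokenizer, device, global_step, num_data_epochs):
    train_numbers, train_mean, train_median = compute_train_stats_sketch(
        args, tokenizer, device, num_data_epochs)
    return evaluation(args, model, tokenizer, device, global_step, split='valid',
                      train_mean=train_mean, train_median=train_median,
                      train_numbers=train_numbers)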
def anomaly_sample(input_values, output_values, output_mask, train_numbers, mode, is_disbert):
    device = output_values.device
    b, s = output_values.size()
    if mode == 'random':
        # Draw replacement numbers uniformly from the training set.
        random_numbers = torch.tensor(np.random.choice(train_numbers, b * s),
                                      dtype=torch.float, device=device)
        random_numbers = random_numbers.view(b, s)
    elif mode == 'sample':
        # Draw from a truncated normal fit to the training numbers.
        mean = np.mean(train_numbers)
        std = np.std(train_numbers)
        low = 0.1
        upp = 10 ** 16 - 1.0
        cutoff_norm = truncnorm((low - mean) / std, (upp - mean) / std, loc=mean, scale=std)
        random_numbers = torch.tensor(cutoff_norm.rvs(b * s), dtype=torch.float, device=device)
        random_numbers = random_numbers.view(b, s)
    elif mode == 'string':
        # Corrupt the digit string with a random add / delete / swap edit.
        opt = np.random.choice(3)
        np_intermed = v_np_str(output_values.cpu().numpy())
        if opt == 0:
            np_intermed = v_np_add(np_intermed)
        elif opt == 1:
            np_intermed = v_np_del(np_intermed)
        elif opt == 2:
            np_intermed = v_np_swap(np_intermed)
        np_intermed = v_np_float(np_intermed)
        random_numbers = torch.tensor(np_intermed, dtype=torch.float, device=device)
    elif mode == 'add':
        np_intermed = v_np_str(output_values.cpu().numpy())
        np_intermed = v_np_add(np_intermed)
        np_intermed = v_np_float(np_intermed)
        random_numbers = torch.tensor(np_intermed, dtype=torch.float, device=device)
    elif mode == 'swap':
        # Swap the first two digits.
        np_intermed = v_np_str(output_values.cpu().numpy())
        np_intermed = v_np_swap(np_intermed)
        np_intermed = v_np_float(np_intermed)
        random_numbers = torch.tensor(np_intermed, dtype=torch.float, device=device)
    output_fake_labels = torch.zeros(b, s, device=device)
    # Count how often the anomaly keeps the true exponent (despite the name,
    # this is a count used downstream, not an AUC).
    true_values = torch.masked_select(output_values, output_mask.bool())
    fake_values = torch.masked_select(random_numbers, output_mask.bool())
    true_exp_ids = fexp(true_values)
    fake_exp_ids = fexp(fake_values)
    oracle_auc = torch.sum(true_exp_ids == fake_exp_ids)
    if is_disbert:
        # Splice the anomalies into the input values at masked positions.
        input_anom_numbers = input_values * (1 - output_mask.float()) + output_mask.float() * random_numbers
        return input_anom_numbers, output_fake_labels, oracle_auc
    else:
        output_anom_numbers = output_values * (1 - output_mask.float()) + output_mask.float() * random_numbers
        return output_anom_numbers, output_fake_labels, oracle_auc
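# The vectorized string-edit helpers used above (v_np_str, v_np_add, v_np_del,
# v_np_swap, v_np_float) are defined elsewhere. A minimal sketch of what they
# might look like, assuming non-negative inputs and fixed-point formatting;
# the edit positions, digit choices, and *_sketch names are assumptions.
import random

def _add_digit(s):
    # insert a random digit at a random position
    i = random.randrange(len(s) + 1)
    return s[:i] + str(random.randrange(10)) + s[i:]

def _del_char(s):
    # delete one random character, keeping at least one
    if len(s) <= 1:
        return s
    i = random.randrange(len(s))
    return s[:i] + s[i + 1:]

def _swap_first_two(s):
    # swap the first two characters
    if len(s) < 2:
        return s
    return s[1] + s[0] + s[2:]

v_np_str_sketch = np.vectorize(lambda x: f"{x:f}")
v_np_add_sketch = np.vectorize(_add_digit)
v_np_del_sketch = np.vectorize(_del_char)
v_np_swap_sketch = np.vectorize(_swap_first_two)
v_np_float_sketch = np.vectorize(float)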