Example #1
def disbert_custom_forward(args, model, batch, train_numbers, do_eval, strategy=None, split=None):
    input_ids, attention_mask, input_values, values_bool, output_values, output_mask = batch
    batch_size, _ = input_ids.size()
    device = input_ids.device

    # corrupt the target numbers with random draws from the training set;
    # the discriminator labels these 0 (fake) and the originals 1 (true)
    input_anom_values, output_fake_labels, _ = anomaly_sample(input_values, output_values, output_mask, train_numbers, 'random', True)
    output_true_labels = torch.ones_like(output_fake_labels)
    del input_values
    input_true_values = output_values
    
    if args.embed_digit:
        input_true_digits = values_to_string(input_true_values)
        input_anom_digits = values_to_string(input_anom_values)
    else:
        input_anom_digits = None
        input_true_digits = None

    fake_loss = model(input_ids, input_anom_values, values_bool, attention_mask,
                      input_digits=input_anom_digits, output_values=None,
                      output_mask=output_mask, output_labels=output_fake_labels)

    true_loss = model(input_ids, input_true_values, values_bool, attention_mask,
                      input_digits=input_true_digits, output_values=None,
                      output_mask=output_mask, output_labels=output_true_labels)

    if do_eval and split == 'test':
        if strategy == 'one':
            masked_values = torch.masked_select(output_values, output_mask.bool())
            true_exp_ids = fexp(masked_values)
            masked_ind = torch.where(output_mask == 1)
            # score every candidate exponent at the masked slots and keep the
            # summed log-likelihood per example
            all_scores = torch.zeros((batch_size, args.n_exponent), device=device)
            ind = 0
            for i in range(args.min_exponent, args.max_exponent):
                input_anom_values = output_values.clone()
                input_anom_values[masked_ind] = 10.0**i
                if args.embed_digit:
                    input_anom_digits = values_to_string(input_anom_values)
                else:
                    input_anom_digits = None

                _, outputs = model(input_ids, input_anom_values, values_bool, attention_mask,
                                   input_digits=input_anom_digits, output_values=None,
                                   output_mask=output_mask, output_labels=output_true_labels, do_eval=True)
                all_scores[:, ind] = outputs['log_likelihood'].sum(dim=1)
                ind += 1

            # map the argmax embedding index back into number space; strategy
            # 'one' masks exactly one number per example, so the predictions
            # align with the masked-out targets
            pred_exp_ids = torch.argmax(all_scores, dim=1)
            pred_exp_ids += args.min_exponent

            exp_acc = torch.sum(pred_exp_ids == true_exp_ids)
            return fake_loss, true_loss, exp_acc

    return fake_loss, true_loss
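
The examples on this page lean on two helpers, fexp and fman, that are not shown. A minimal sketch consistent with how they are called (base-10 exponent and mantissa, with an ignore flag guarding non-positive inputs) might look like this; the repository's actual definitions may differ:

import torch

def fexp(values, ignore=False):
    # base-10 exponent: floor(log10(|x|)); with ignore=True, clamp away
    # zeros so log10 stays finite instead of returning -inf
    safe = values.abs().clamp(min=1e-12) if ignore else values.abs()
    return torch.floor(torch.log10(safe)).long()

def fman(values, ignore=False):
    # base-10 mantissa m such that x = m * 10**e (sign preserved)
    exp = fexp(values, ignore=ignore)
    return values / torch.pow(10.0, exp.float())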
Example #2
    def predict(self, mean_prediction_k, logvar_prediction, exponent_prediction):
        b, s, k = mean_prediction_k.size()  # (batch, sequence, n_exponent_bins)

        # pick the best exponent bin per position, look up its scale factor,
        # and select the matching component mean
        exp_ind = torch.argmax(exponent_prediction, dim=2)
        f_e = torch.take(self.f_e, exp_ind)
        mean_prediction = torch.gather(mean_prediction_k, 2, exp_ind.unsqueeze(dim=2)).squeeze(dim=2)
        pred_values = mean_prediction * f_e
        pred_exponent = fexp(pred_values, ignore=True)
        pred_mantissa = fman(pred_values, ignore=True)

        return pred_mantissa, pred_exponent
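
To unpack the take/gather pair above, here is a standalone shape trace with dummy tensors (the sizes and the f_e_table values are illustrative assumptions, not from the source):

import torch

b, s, k = 2, 3, 4                                  # batch, sequence, exponent bins
mean_prediction_k = torch.randn(b, s, k)
exponent_prediction = torch.randn(b, s, k)
f_e_table = 10.0 ** torch.arange(k).float()        # assumed per-bin scale factors

exp_ind = torch.argmax(exponent_prediction, dim=2)             # (b, s)
f_e = torch.take(f_e_table, exp_ind)                           # (b, s), flat lookup
mean = torch.gather(mean_prediction_k, 2,
                    exp_ind.unsqueeze(dim=2)).squeeze(dim=2)   # (b, s)
pred_values = mean * f_e                                       # reconstructed values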
Example #3
    def oracle_predict(self, output_values):
        '''pick component whose mean is closest to x'''
        b, s = output_values.size()
        # (b, s, n_components): one copy of the component means per position
        means_expanded = self.means.repeat(b, s, 1)

        # (b, s, n_components): one copy of each value per component
        output_values_rpt = output_values.repeat(self.n_components, 1,
                                                 1).permute(1, 2, 0)

        diff = torch.abs(means_expanded - output_values_rpt)
        oracle_pi = torch.argmin(diff, dim=2)
        oracle_predictions = self.means[oracle_pi]

        pred_exponent = fexp(oracle_predictions, ignore=True)
        pred_mantissa = fman(oracle_predictions, ignore=True)
        return pred_mantissa, pred_exponent
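
A quick numerical check of the broadcasting in oracle_predict, assuming self.means is a 1-D tensor of component means (dummy values, not from the source):

import torch

b, s, k = 2, 3, 4
means = torch.tensor([0.1, 1.0, 10.0, 100.0])                 # assumed component means
output_values = torch.rand(b, s) * 100

means_expanded = means.repeat(b, s, 1)                        # (b, s, k)
values_rpt = output_values.repeat(k, 1, 1).permute(1, 2, 0)   # (b, s, k)
oracle_pi = torch.argmin((means_expanded - values_rpt).abs(), dim=2)
oracle_predictions = means[oracle_pi]                         # (b, s) nearest means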
Example #4
def numeracy_metrics(output_values, output_mask, split, histograms):
    key = f"{split}_values"
    if histograms.get(key) is None:
        histograms[key] = []
        histograms[f"{split}_exponents"] = []
        histograms[f"{split}_mantissas"] = []

    # keep only the masked positions (as_tuple avoids the deprecated bare nonzero())
    mask_index = torch.nonzero(output_mask.view(-1) == 1, as_tuple=True)[0]
    output_values = output_values.view(-1)[mask_index].cpu()

    exponents = fexp(output_values)
    mantissas = fman(output_values)

    histograms[f"{split}_values"].extend(output_values.numpy())
    histograms[f"{split}_exponents"].extend(exponents.numpy())
    histograms[f"{split}_mantissas"].extend(mantissas.numpy())
    return histograms
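
numeracy_metrics is meant to be folded over batches, mutating and returning the same dict; a minimal usage sketch with dummy tensors (and assuming fexp/fman behave as sketched under Example #1):

import torch

histograms = {}
for _ in range(3):                                 # stand-in for a dataloader loop
    output_values = torch.rand(8, 16) * 1000 + 1.0
    output_mask = (torch.rand(8, 16) > 0.9).long()
    histograms = numeracy_metrics(output_values, output_mask, 'valid', histograms)

print(len(histograms['valid_values']), 'masked values collected')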
Example #5
def get_numbers_from_split(args, tokenizer, device, num_data_epochs, split):
    counter_nums = Counter()
    epoch = 0
    epoch_dataset = NumericalPregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
                                        num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory, split=split)
    train_sampler = RandomSampler(epoch_dataset)
    train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    all_nums = []
    with tqdm(total=len(train_dataloader), desc=f"Getting {split} #s") as pbar:
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)

            input_ids, attention_mask, input_values, values_bool, output_values, output_mask = batch
            batch_size, _ = input_ids.size()

            assert torch.all(output_mask.sum(dim=1) > 0)

            for i in range(batch_size):
                # nonzero() without as_tuple is deprecated in recent PyTorch
                mask_index = torch.nonzero(output_mask[i] == 1, as_tuple=True)[0].cpu()
                values = output_values[i][mask_index]
                values = values.tolist()
                all_nums.extend(values)
                counter_nums[len(values)] += 1


    print('Distribution of numbers per datum', counter_nums)
    all_nums = np.array(all_nums)
    
    print(f'min:{np.min(all_nums)}, max:{np.max(all_nums)}, mean: {np.mean(all_nums)}')
    print(f'percentile, 25:{np.percentile(all_nums, 25)}, 50:{np.percentile(all_nums, 50)}, 75:{np.percentile(all_nums, 75)}')
    
    exp_counter = Counter()
    tensor_nums = torch.tensor(all_nums, dtype=torch.float)
    all_exps = fexp(tensor_nums)
    all_exps = all_exps.numpy()
    exp_counter.update(all_exps)
    print('exp_counter', exp_counter)

    return all_nums
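
The returned array is what the other examples consume as train_numbers; hypothetical glue code (variable values assumed) tying it to the evaluation entry point shown in Example #8:

import numpy as np

train_numbers = get_numbers_from_split(args, tokenizer, device,
                                       num_data_epochs, split='train')
train_mean = float(np.mean(train_numbers))
train_median = float(np.median(train_numbers))
all_metrics = evaluation(args, model, tokenizer, device, global_step,
                         split='valid', train_mean=train_mean,
                         train_median=train_median,
                         train_numbers=train_numbers)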
Example #6
    def predict(self, mu_pred, **kwargs):
        x_pred = self.f_forward(mu_pred, **kwargs)
        pred_exponent = fexp(x_pred, ignore=True)
        pred_mantissa = fman(x_pred, ignore=True)
        return pred_mantissa, pred_exponent
Example #7
    def predict(self, logits):
        # logits: (batch, seq, n_components); snap to the argmax component mean
        ind = torch.argmax(logits, dim=2)
        pred_values = self.means[ind]
        pred_exponent = fexp(pred_values, ignore=True)
        pred_mantissa = fman(pred_values, ignore=True)
        return pred_mantissa, pred_exponent
Example #8
def evaluation(args, model, tokenizer, device, global_step, split='valid', train_mean=None, train_median=None, train_numbers=None):
    all_metrics = {}
    num_data_epochs = args.epochs
    epoch = 0
    if split == 'train':
        strategies = ['']
    else:
        strategies = ['one', 'all']
    
    for strategy in strategies:
        epoch_dataset = NumericalPregeneratedDataset(epoch=epoch,
                         training_path=args.pregenerated_data, tokenizer=tokenizer,
                         num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory,
                         split=split, strategy=strategy)

        train_sampler = SequentialSampler(epoch_dataset)
        eval_metrics = {}
        histograms = {}
        
        with torch.set_grad_enabled(False):
            if args.do_anomaly and strategy == 'one':
                train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.eval_batch_size)
                options = ['random', 'string']
                for option in options:
                    print('anomaly', option)
                    eval_metrics = anomaly_evaluation(args, model, device, tokenizer,
                     train_dataloader, eval_metrics, '', train_numbers, option)

        train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.eval_batch_size)
        nb_eval_examples = 0.0
        with torch.set_grad_enabled(False):
            with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
                
                for step, batch in enumerate(train_dataloader):

                    batch = tuple(t.to(device) for t in batch)
                    input_ids, attention_mask, input_values, values_bool, output_values, output_mask = batch

                    if args.embed_digit:
                        input_digits = values_to_string(input_values)
                    else:
                        input_digits = None

                    batch_size = input_ids.size(0)
                    histograms = numeracy_metrics(output_values, output_mask, split, histograms)

                    if split in ('train', 'valid', 'test'):
                        torch.cuda.empty_cache()
                        loss, outputs = model(input_ids, input_values, values_bool, attention_mask,
                            input_digits=input_digits, output_values=output_values,
                            output_mask=output_mask, do_eval=True)

                        pred_mantissa, pred_exponent = outputs['pred_mantissa'], outputs['pred_exponent']
                        true_mantissa, true_exponent = fman(output_values), fexp(output_values)

                        eval_metrics = loss_metrics(loss, eval_metrics)
                        
                        if args.do_log:
                            metric_value = outputs['flow_mu_pred']
                            eval_metrics = log_metrics(eval_metrics, metric_value, output_mask, mode='')

                        if args.do_flow:
                            flow_items = {k: v for k, v in outputs.items() if k.startswith('flow')}
                            eval_metrics = flow_metrics(eval_metrics, flow_items, output_mask, args.flow_v, mode='')

        
                        if args.do_gmm:
                            oracle_mantissa, oracle_exponent = model.oracle_predict(output_values)
                            eval_metrics = mantissa_metrics(true_mantissa, oracle_mantissa, output_mask, eval_metrics, 'oracle_')
                            eval_metrics, histograms = exponent_metrics(true_exponent, oracle_exponent, output_mask, eval_metrics, histograms, 'oracle_')
                            eval_metrics, _, _ = regression_metrics(true_mantissa, oracle_mantissa,
                                            true_exponent, oracle_exponent,
                                            output_mask, output_values, eval_metrics, 'oracle_')

                        if train_mean is not None:
                            train_means = torch.zeros_like(true_mantissa) + train_mean
                            train_mean_exponents = fexp(train_means)
                            train_mean_mantissas = fman(train_means)
                            eval_metrics, histograms = exponent_metrics(true_exponent, train_mean_exponents, output_mask, eval_metrics, histograms, 'mean_')
                            eval_metrics, _, _ = regression_metrics(true_mantissa, train_mean_mantissas,
                                            true_exponent, train_mean_exponents,
                                            output_mask, output_values, eval_metrics, 'mean_')

                            train_medians = torch.zeros_like(true_mantissa) + train_median
                            train_median_exponents = fexp(train_medians)
                            train_median_mantissas = fman(train_medians)
                            eval_metrics, histograms = exponent_metrics(true_exponent, train_median_exponents, output_mask, eval_metrics, histograms, 'median_')
                            eval_metrics, _, _ = regression_metrics(true_mantissa, train_median_mantissas,
                                            true_exponent, train_median_exponents,
                                            output_mask, output_values, eval_metrics, 'median_')

                        eval_metrics = mantissa_metrics(true_mantissa, pred_mantissa, output_mask, eval_metrics)
                        eval_metrics, histograms = exponent_metrics(true_exponent, pred_exponent, output_mask, eval_metrics, histograms)
                        eval_metrics, true_numbers, pred_numbers = regression_metrics(true_mantissa, pred_mantissa,
                                            true_exponent, pred_exponent,
                                            output_mask, output_values, eval_metrics)

                    nb_eval_examples += torch.sum(output_mask).float().item()


                if split in ('train', 'valid', 'test'):
                    if strategy != '':
                        prefix = f'{split}_{strategy}'
                    else:
                        prefix = split
                    summary = summarize_metrics(eval_metrics, nb_eval_examples, prefix)
                    all_metrics.update(summary)
                    log_wandb(summary, global_step)
    
    return all_metrics
Example #9
def anomaly_sample(input_values, output_values, output_mask, train_numbers,
                   mode, is_disbert):
    device = output_values.device
    b, s = output_values.size()
    if mode == 'random':
        random_numbers = torch.tensor(np.random.choice(train_numbers, b * s),
                                      dtype=torch.float,
                                      device=device)
        random_numbers = random_numbers.view(b, s)
    elif mode == 'sample':
        mean = np.mean(train_numbers)
        std = np.std(train_numbers)
        low = 0.1
        upp = 10**16 - 1.0
        cutoff_norm = truncnorm((low - mean) / std, (upp - mean) / std,
                                loc=mean,
                                scale=std)
        random_numbers = torch.tensor(cutoff_norm.rvs(b * s),
                                      dtype=torch.float,
                                      device=device)
        random_numbers = random_numbers.view(b, s)
    elif mode == 'string':
        opt = np.random.choice(3)
        np_intermed = v_np_str(output_values.cpu().numpy())
        if opt == 0:
            np_intermed = v_np_add(np_intermed)
        elif opt == 1:
            np_intermed = v_np_del(np_intermed)
        elif opt == 2:
            np_intermed = v_np_swap(np_intermed)
        np_intermed = v_np_float(np_intermed)
        random_numbers = torch.tensor(np_intermed,
                                      dtype=torch.float,
                                      device=device)
    elif mode == 'add':
        np_intermed = v_np_str(output_values.cpu().numpy())
        np_intermed = v_np_add(np_intermed)
        np_intermed = v_np_float(np_intermed)
        random_numbers = torch.tensor(np_intermed,
                                      dtype=torch.float,
                                      device=device)
    elif mode == 'swap':
        # swap the first two digits of each number's string form
        np_intermed = v_np_str(output_values.cpu().numpy())
        np_intermed = v_np_swap(np_intermed)
        np_intermed = v_np_float(np_intermed)
        random_numbers = torch.tensor(np_intermed,
                                      dtype=torch.float,
                                      device=device)
    else:
        # guard against an undefined random_numbers further down
        raise ValueError(f'unknown anomaly mode: {mode}')

    output_fake_labels = torch.zeros(b, s, device=device)

    # count masked positions where the corrupted value keeps the original's
    # exponent; an exponent-only oracle cannot tell these apart
    true_values = torch.masked_select(output_values, output_mask.bool())
    fake_values = torch.masked_select(random_numbers, output_mask.bool())
    true_exp_ids = fexp(true_values)
    fake_exp_ids = fexp(fake_values)
    oracle_auc = torch.sum(true_exp_ids == fake_exp_ids)

    if is_disbert:
        # splice the corrupted numbers into the masked positions only
        input_anom_numbers = input_values * (
            1 - output_mask.float()) + output_mask.float() * random_numbers
        return input_anom_numbers, output_fake_labels, oracle_auc
    else:
        output_anom_numbers = output_values * (
            1 - output_mask.float()) + output_mask.float() * random_numbers
        return output_anom_numbers, output_fake_labels, oracle_auc
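
For reference, Example #1 invokes this sampler positionally as anomaly_sample(input_values, output_values, output_mask, train_numbers, 'random', True); with is_disbert=True the corrupted numbers are spliced into input_values, otherwise into output_values:

input_anom_values, output_fake_labels, oracle_auc = anomaly_sample(
    input_values, output_values, output_mask, train_numbers,
    mode='random', is_disbert=True)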