Example #1
    def evaluate(self, data_loader, data_dict):
        device = self.device
        self.Experts.eval()
        self.Gate.eval()
        all_start_logits = []
        all_end_logits = []
        with torch.no_grad(), tqdm(
                total=math.ceil(len(data_loader.dataset) /
                                self.batch_size)) as progress_bar:
            for batch in data_loader:
                # Setup for forward
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                predict_start_list, predict_end_list, expert_hidden = self.Experts(
                    input_ids, attention_mask)
                gate_hidden = self.Gate(input_ids, attention_mask)
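                # Combine the per-expert start/end predictions using the gate output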
                selected_start_prob, selected_end_prob = self.get_gated_experts_prediction(
                    predict_start_list, predict_end_list, gate_hidden)
                # Accumulate logits for post-processing
                all_start_logits.append(selected_start_prob)
                all_end_logits.append(selected_end_prob)
                progress_bar.update(1)

        # Get F1 and EM scores
        start_logits = torch.cat(all_start_logits).cpu().numpy()
        end_logits = torch.cat(all_end_logits).cpu().numpy()
        preds = util.postprocess_qa_predictions(data_dict,
                                                data_loader.dataset.encodings,
                                                (start_logits, end_logits))
        results = util.eval_dicts(data_dict, preds)
        results_list = [('F1', results['F1']), ('EM', results['EM'])]
        results = OrderedDict(results_list)
        return preds, results
Example #2
    def evaluate(self, model, data_loader, data_dict, return_preds=False, split='validation', experts=False):
        device = self.device

        model.eval()
        pred_dict = {}
        
        all_start_logits = []
        all_end_logits = []

        with torch.no_grad(), \
                tqdm(total=len(data_loader.dataset)) as progress_bar:
            # When `experts` is True, mix the expert logits with the gate weights (MoE path)
            for batch in data_loader:
                # Setup for forward
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                batch_size = len(input_ids)

                if experts is False:
                    outputs = model(input_ids, attention_mask=attention_mask)
                    start_logits, end_logits = outputs.start_logits, outputs.end_logits
                else:
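                    # Run each expert separately, then mix their logits with the gate weights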
                    outputs1 = model.expert1(input_ids, attention_mask=attention_mask)
                    start_logits1, end_logits1 = outputs1.start_logits, outputs1.end_logits
                    outputs2 = model.expert2(input_ids, attention_mask=attention_mask)
                    start_logits2, end_logits2 = outputs2.start_logits, outputs2.end_logits
                    outputs3 = model.expert3(input_ids, attention_mask=attention_mask)
                    start_logits3, end_logits3 = outputs3.start_logits, outputs3.end_logits

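                    # Gate weights over the three experts; assumed here to take the same inputs as the experts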
                    expert_weights = model.gate(input_ids, attention_mask)
                    start_logits = (start_logits1 * expert_weights[0]
                                    + start_logits2 * expert_weights[1]
                                    + start_logits3 * expert_weights[2])
                    end_logits = (end_logits1 * expert_weights[0]
                                  + end_logits2 * expert_weights[1]
                                  + end_logits3 * expert_weights[2])

                all_start_logits.append(start_logits)
                all_end_logits.append(end_logits)
                progress_bar.update(batch_size)

        # Get F1 and EM scores
        start_logits = torch.cat(all_start_logits).cpu().numpy()
        end_logits = torch.cat(all_end_logits).cpu().numpy()
        preds = util.postprocess_qa_predictions(data_dict,
                                                 data_loader.dataset.encodings,
                                                 (start_logits, end_logits))
        if split == 'validation':
            results = util.eval_dicts(data_dict, preds)
            results_list = [('F1', results['F1']),
                            ('EM', results['EM'])]
        else:
            results_list = [('F1', -1.0),
                            ('EM', -1.0)]
        results = OrderedDict(results_list)
        if return_preds:
            return preds, results
        return results
Example #3
    def evaluate(self,
                 model,
                 data_loader,
                 data_dict,
                 return_preds=False,
                 split='validation'):
        device = self.device
        # global_idx = 0
        # tbx = SummaryWriter(self.save_dir)

        model.eval()
        pred_dict = {}
        all_start_logits = []
        all_end_logits = []
        with torch.no_grad(), \
                tqdm(total=len(data_loader.dataset)) as progress_bar:
            for batch in data_loader:
                # Setup for forward
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                batch_size = len(input_ids)
                # Forward
                outputs = model(input_ids, attention_mask=attention_mask)
                start_logits, end_logits = outputs.start_logits, outputs.end_logits
                # TODO: compute loss
                # loss = outputs[0]
                # tbx.add_scalar('evaluate/NLL', loss.item(), global_idx)
                # global_idx += 1

                all_start_logits.append(start_logits)
                all_end_logits.append(end_logits)
                progress_bar.update(batch_size)

        # Get F1 and EM scores
        start_logits = torch.cat(all_start_logits).cpu().numpy()
        end_logits = torch.cat(all_end_logits).cpu().numpy()
        preds = util.postprocess_qa_predictions(data_dict,
                                                data_loader.dataset.encodings,
                                                (start_logits, end_logits))
        if split == 'validation':
            results = util.eval_dicts(data_dict, preds)
            results_list = [('F1', results['F1']), ('EM', results['EM'])]
        else:
            results_list = [('F1', -1.0), ('EM', -1.0)]
        results = OrderedDict(results_list)
        if return_preds:
            return preds, results
        return results
Example #4
    def evaluate(self, Experts, gate, data_loader, data_dict, return_preds=False, split='validation'):
        device = self.device
        for expert in Experts:
            expert.eval()
        gate.eval()
        pred_dict = {}
        all_start_logits = []
        all_end_logits = []
        with torch.no_grad(), tqdm(total=len(data_loader.dataset)) as progress_bar:
            for batch in data_loader:
                # Setup for forward
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                # Forward
                outputs = gate(input_ids)
                selected_expert = torch.argmax(outputs)
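                # The gate selects a single expert, which then handles the entire batch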
                batch_size = len(input_ids)
                # total_loss, start_logits, end_logits, distilbert_output.hidden_states
                _, start_logits, end_logits,_ = Experts[selected_expert.item()](input_ids, attention_mask=attention_mask)

                # TODO: compute loss

                all_start_logits.append(start_logits)
                all_end_logits.append(end_logits)
                progress_bar.update(batch_size)

        # Get F1 and EM scores
        start_logits = torch.cat(all_start_logits).cpu().numpy()
        end_logits = torch.cat(all_end_logits).cpu().numpy()
        preds = util.postprocess_qa_predictions(data_dict,
                                                 data_loader.dataset.encodings,
                                                 (start_logits, end_logits))
        if split == 'validation':
            results = util.eval_dicts(data_dict, preds)
            results_list = [('F1', results['F1']),
                            ('EM', results['EM'])]
        else:
            results_list = [('F1', -1.0),
                            ('EM', -1.0)]
        results = OrderedDict(results_list)
        if return_preds:
            return preds, results
        return results
Example #5
    def evaluate(self,
                 data_loader,
                 data_dict,
                 return_preds=False,
                 split='validation'):
        device = self.device

        self.model.eval()
        pred_dict = {}
        all_start_logits = []
        all_end_logits = []
        with torch.no_grad(), \
                tqdm(total=len(data_loader.dataset)) as progress_bar:
            for batch in data_loader:
                # Setup for forward
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                batch_size = len(input_ids)
                # Forward
                outputs = self.model(input_ids, attention_mask=attention_mask)
                start_logits, end_logits = outputs.start_logits, outputs.end_logits

                all_start_logits.append(start_logits)
                all_end_logits.append(end_logits)
                progress_bar.update(batch_size)

        # Get F1 and EM scores
        start_logits = torch.cat(all_start_logits).cpu().numpy()
        end_logits = torch.cat(all_end_logits).cpu().numpy()
        preds = util.postprocess_qa_predictions(data_dict,
                                                data_loader.dataset.encodings,
                                                (start_logits, end_logits))
        if split == 'validation':
            results = util.eval_dicts(data_dict, preds)
            results_list = [('F1', results['F1']), ('EM', results['EM'])]
        else:
            results_list = [('F1', -1.0), ('EM', -1.0)]
        results = OrderedDict(results_list)
        if return_preds:
            return preds, results
        return results
Example #6
    def evaluate(self,
                 model,
                 discriminator,
                 data_loader,
                 data_dict,
                 return_preds=False,
                 split='validation'):
        device = self.device

        model.eval()
        pred_dict = {}
        all_start_logits = []
        all_end_logits = []
        all_dis_logits = []
        all_ground_truth_data_set_ids = []
        with torch.no_grad(), \
          tqdm(total=len(data_loader.dataset)) as progress_bar:
            for batch in data_loader:
                # Setup for forward
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                data_set_ids = batch['data_set_id'].to(device)
                batch_size = len(input_ids)
                # Forward
                outputs = model(input_ids,
                                attention_mask=attention_mask,
                                output_hidden_states=True)
                start_logits, end_logits = outputs.start_logits, outputs.end_logits
                hidden_states = outputs.hidden_states[-1]
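                # Score the dataset discriminator on the final-layer hidden states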
                _, dis_logits = self.forward_discriminator(
                    discriminator,
                    hidden_states,
                    data_set_ids,
                    full_adv=self.full_adv)

                # TODO: compute loss

                all_start_logits.append(start_logits)
                all_end_logits.append(end_logits)
                all_dis_logits.append(dis_logits)
                all_ground_truth_data_set_ids.append(data_set_ids)
                progress_bar.update(batch_size)

        # Get F1 and EM scores
        start_logits = torch.cat(all_start_logits).cpu().numpy()
        end_logits = torch.cat(all_end_logits).cpu().numpy()
        dis_logits = torch.cat(all_dis_logits).cpu().numpy()
        ground_truth_data_set_ids = torch.cat(
            all_ground_truth_data_set_ids).cpu().numpy()
        preds = util.postprocess_qa_predictions(data_dict,
                                                data_loader.dataset.encodings,
                                                (start_logits, end_logits))

        if split == 'validation':
            discriminator_eval_results = util.eval_discriminator(
                data_dict, ground_truth_data_set_ids, dis_logits)
            results = util.eval_dicts(data_dict, preds)
            results_list = [('F1', results['F1']), ('EM', results['EM']),
                            ('discriminator_precision',
                             discriminator_eval_results['precision'])]
        else:
            results_list = [('F1', -1.0), ('EM', -1.0),
                            ('discriminator_precision', -1.0)]
        results = OrderedDict(results_list)
        if return_preds:
            return preds, results
        return results