Example #1
    def _gather_loss_distributed(self, loss_sum, num_samples, **kwargs):
        loss_sum = xm.mesh_reduce('gather_loss', loss_sum, np.sum)
        num_samples = xm.mesh_reduce('gather_num_samples', num_samples, np.sum)

        loss = loss_sum / num_samples

        return loss, loss_sum, num_samples
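For context, `xm.mesh_reduce(tag, value, reduce_fn)` collects the value passed by every XLA core into a list and applies `reduce_fn` to that list on each core (the scalar and tensor tests in Examples #12 and #14 below illustrate this). A minimal sketch of averaging a per-core metric; the helper name is illustrative:

import numpy as np
import torch_xla.core.xla_model as xm

def average_across_cores(tag, value):
    # Every core contributes its local `value`; np.mean is applied to the
    # list of collected values, so all cores see the same averaged result.
    return xm.mesh_reduce(tag, value, np.mean)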
Example #2
def eval(model, loss, dataloader, device, verbose, epoch, **kwargs):
    print_fn = print
    if device.type == "xla":
        import torch_xla.core.xla_model as xm

        print_fn = xm.master_print

    model.eval()
    total = 0
    correct1 = 0
    correct5 = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total += loss(output, target).item() * data.size(0)
            _, pred = output.topk(5, dim=1)
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            correct1 += correct[:, :1].sum().item()
            correct5 += correct[:, :5].sum().item()
            total_samples += data.size()[0]
    average_loss = 1.0 * total / total_samples
    accuracy1 = 100.0 * correct1 / total_samples
    accuracy5 = 100.0 * correct5 / total_samples
    print_fn(f"Epoch {epoch} evaluation: Average loss: {average_loss:.4f}, "
             f"Top 1 Accuracy: {correct1}/{total_samples} ({accuracy1:.2f}%)")

    if device.type == "xla":
        average_loss = xm.mesh_reduce("test_average_loss", average_loss,
                                      np.mean)
        accuracy1 = xm.mesh_reduce("test_accuracy1", accuracy1, np.mean)
        accuracy5 = xm.mesh_reduce("test_accuracy5", accuracy5, np.mean)
    return average_loss, accuracy1, accuracy5
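A loop like the one above is typically run once per core via `xmp.spawn`, with the dataloader wrapped in a `ParallelLoader`. A minimal sketch, assuming the model and loader factories (`build_model`, `build_loader`) are defined elsewhere:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    model = build_model().to(device)  # hypothetical model factory
    loader = pl.ParallelLoader(build_loader(), [device]).per_device_loader(device)
    eval(model, torch.nn.CrossEntropyLoss(), loader, device, verbose=True, epoch=0)

# xmp.spawn(_mp_fn, nprocs=8)  # one process per TPU core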
Example #3
def train(args, train_loader, model, device, optimizer, scheduler, epoch, f):
    total_loss = AverageMeter()
    losses1 = AverageMeter() # start
    losses2 = AverageMeter() # end
    accuracies1 = AverageMeter() # start
    accuracies2 = AverageMeter() # end

    model.train()

    t = tqdm(train_loader, disable=not xm.is_master_ordinal())
    for step, d in enumerate(t):
        
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        token_type_ids = d["token_type_ids"].to(device)
        start_position = d["start_position"].to(device)
        end_position = d["end_position"].to(device)

        model.zero_grad()

        logits1, logits2 = model(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            token_type_ids=token_type_ids, 
            position_ids=None, 
            head_mask=None
        )

        y_true = (start_position, end_position)
        loss1, loss2 = loss_fn((logits1, logits2), (start_position, end_position))
        loss = loss1 + loss2

        acc1, n_position1 = get_position_accuracy(logits1, start_position)
        acc2, n_position2 = get_position_accuracy(logits2, end_position)

        total_loss.update(loss.item(), n_position1)
        losses1.update(loss1.item(), n_position1)
        losses2.update(loss2.item(), n_position2)
        accuracies1.update(acc1, n_position1)
        accuracies2.update(acc2, n_position2)

        
        loss.backward()
        xm.optimizer_step(optimizer)
        scheduler.step()
        print_loss = xm.mesh_reduce("loss_reduce", total_loss.avg, reduce_fn)
        print_acc1 = xm.mesh_reduce("acc1_reduce", accuracies1.avg, reduce_fn)
        print_acc2 = xm.mesh_reduce("acc2_reduce", accuracies2.avg, reduce_fn)
        t.set_description(f"Train E:{epoch+1} - Loss:{print_loss:0.2f} - acc1:{print_acc1:0.2f} - acc2:{print_acc2:0.2f}")


    log_ = f"Epoch : {epoch+1} - train_loss : {total_loss.avg} - \n \
    train_loss1 : {losses1.avg} - train_loss2 : {losses2.avg} - \n \
    train_acc1 : {accuracies1.avg} - train_acc2 : {accuracies2.avg}"

    f.write(log_ + "\n\n")
    f.flush()
    
    return total_loss.avg
Example #4
    def compute(self) -> Tuple[float, float, float]:
        """
        Compute metrics with accumulated statistics

        Returns:
            tuple of metrics: precision, recall, f1 score
        """
        # ddp hotfix, could be done better
        # but metric must handle DDP on its own
        if self._ddp_backend == "xla":
            self.statistics = {
                k: xm.mesh_reduce(k, v, np.sum)
                for k, v in self.statistics.items()
            }
        elif self._ddp_backend == "ddp":
            for key in self.statistics:
                value: List[int] = all_gather(self.statistics[key])
                value: int = sum(value)
                self.statistics[key] = value

        precision_value, recall_value, f1_value = get_binary_metrics(
            tp=self.statistics["tp"],
            fp=self.statistics["fp"],
            fn=self.statistics["fn"],
            zero_division=self.zero_division,
        )
        return precision_value, recall_value, f1_value
Example #5
def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    model.train()
    
    train_loss = []
    
    for bi, data in enumerate(data_loader):

        ids = data['ids']
        mask = data['mask']
        targets = data['target']

        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)

        optimizer.zero_grad()
        
        outputs = model(input_ids=ids, attention_mask=mask)
        loss = loss_fn(outputs, targets)
        
        train_loss.append(loss.item())
        
        if bi % 500 == 0:
            loss_reduced = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
            xm.master_print(f'bi={bi}, loss={loss_reduced:.4f}')
        
        loss.backward()
        
        xm.optimizer_step(optimizer)

        if scheduler is not None:
            scheduler.step()
            
    return train_loss
Example #6
 def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
     should_stop = torch.tensor(int(should_stop),
                                device=self.lightning_module.device)
     stop = xm.mesh_reduce('stop_signal', should_stop, sum)
     rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
     should_stop = int(stop.item()) == self.world_size
     return should_stop
Example #7
 def evaluate(data_loader, model, device, eval_metric, use_tpu=False):
     losses = AverageMeter()
     final_predictions = []
     final_targets = []
     model.eval()
     with torch.no_grad():
         if use_tpu:
             para_loader = pl.ParallelLoader(data_loader, [device])
             tk0 = tqdm(para_loader.per_device_loader(device),
                        total=len(data_loader),
                        disable=xm.get_ordinal() == 0)
         else:
             tk0 = tqdm(data_loader, total=len(data_loader))
         for b_idx, data in enumerate(tk0):
             for key, value in data.items():
                 data[key] = value.to(device)
             _, loss = model(**data)
             if use_tpu:
                 reduced_loss = xm.mesh_reduce('loss_reduce', loss,
                                               reduce_fn)
                 losses.update(reduced_loss.item(), data_loader.batch_size)
             else:
                 losses.update(loss.item(), data_loader.batch_size)
             tk0.set_postfix(loss=losses.avg)
     return losses.avg
Example #8
def synchronize_rng_state(rng_type: Optional[RNGType] = None,
                          generator: Optional[torch.Generator] = None):
    # Get the proper rng state
    if rng_type == RNGType.TORCH:
        rng_state = torch.get_rng_state()
    elif rng_type == RNGType.CUDA:
        rng_state = torch.cuda.get_rng_state()
    elif rng_type == RNGType.XLA:
        assert is_tpu_available(
        ), "Can't synchronize XLA seeds on an environment without TPUs."
        rng_state = torch.tensor(xm.get_rng_state())
    elif rng_type == RNGType.GENERATOR:
        assert generator is not None, "Need a generator to synchronize its seed."
        rng_state = generator.get_state()

    # Broadcast the rng state from device 0 to other devices
    state = AcceleratorState()
    if state.distributed_type == DistributedType.TPU:
        rng_state = xm.mesh_reduce("random_seed", rng_state, lambda x: x[0])
    elif state.distributed_type == DistributedType.MULTI_GPU:
        rng_state = rng_state.to(state.device)
        torch.distributed.broadcast(rng_state, 0)
        rng_state = rng_state.cpu()

    # Set the broadcast rng state
    if rng_type == RNGType.TORCH:
        torch.set_rng_state(rng_state)
    elif rng_type == RNGType.CUDA:
        torch.cuda.set_rng_state(rng_state)
    elif rng_type == RNGType.XLA:
        xm.set_rng_state(rng_state.item())
    elif rng_type == RNGType.GENERATOR:
        generator.set_state(rng_state)
Example #9
    def test_loop_fn(model, loader, device, context):
        print("***********************")
        print("ENTERING TEST FUNCTION")
        print("***********************")
        print('Evaluating...')
        total_samples = 0
        correct = 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            if step >= FLAGS.test_max_step:
                break
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            if FLAGS.mp:
                correct += pred.eq(target.view_as(pred)).sum()
            else:
                correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        if FLAGS.mp:
            this_accuracy = 100.0 * correct.item() / total_samples
            print("CALLING: mesh_reduce('test_accuracy')")
            this_accuracy = xm.mesh_reduce(
                'test_accuracy', this_accuracy, np.mean
            )
            print("BACK FROM: mesh_reduce('test_accuracy')")
        else:
            this_accuracy = 100.0 * correct / total_samples
            test_utils.print_test_update(device, this_accuracy)
        print("***********************")
        print("LEAVING TEST FUNCTION")
        print("***********************")
        return this_accuracy
Example #10
 def evaluate(self, data_loader, return_predictions=False):
     losses = AverageMeter()
     # guard against a zero print interval (modulo-by-zero) on short loaders
     print_idx = max(1, int(len(data_loader) * self.tpu_print / 100))
     self.model.eval()
     final_predictions = []
     with torch.no_grad():
         if self.use_tpu:
             para_loader = pl.ParallelLoader(data_loader, [self.device])
             tk0 = para_loader.per_device_loader(self.device)
         else:
             tk0 = tqdm(data_loader, total=len(data_loader))
         for b_idx, data in enumerate(tk0):
             for key, value in data.items():
                 data[key] = value.to(self.device)
             batch_preds, loss = self.model(**data)
             if return_predictions:
                 final_predictions.append(batch_preds)
             if self.use_tpu:
                 reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
                 losses.update(reduced_loss.item(), data_loader.batch_size)
             else:
                 if self.use_mean_loss:
                     loss = loss.mean()
                 losses.update(loss.item(), data_loader.batch_size)
             if not self.use_tpu:
                 tk0.set_postfix(loss=losses.avg)
             else:
                 if b_idx % print_idx == 0 or b_idx == len(data_loader):
                     xm.master_print(
                         f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                     )
         if not self.use_tpu:
             tk0.close()
     return losses.avg, final_predictions
Example #11
    def train(self, data_loader):
        losses = AverageMeter()
        self.model.train()
        # guard against a zero print interval (modulo-by-zero) on short loaders
        print_idx = max(1, int(len(data_loader) * self.tpu_print / 100))
        if self.accumulation_steps > 1:
            self.optimizer.zero_grad()
        if self.use_tpu:
            para_loader = pl.ParallelLoader(data_loader, [self.device])
            tk0 = para_loader.per_device_loader(self.device)
        else:
            tk0 = tqdm(data_loader, total=len(data_loader))

        for b_idx, data in enumerate(tk0):
            if self.accumulation_steps == 1 and b_idx == 0:
                self.optimizer.zero_grad()

            if self.model_fn is None:
                for key, value in data.items():
                    data[key] = value.to(self.device)
                _, loss = self.model(**data)
            else:
                loss = self.model_fn(data, self.device, self.model)

            if not self.use_tpu:
                with torch.set_grad_enabled(True):
                    if self.fp16:
                        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    if (b_idx + 1) % self.accumulation_steps == 0:
                        self.optimizer.step()
                        if self.scheduler is not None:
                            self.scheduler.step()
                        if b_idx > 0:
                            self.optimizer.zero_grad()
            else:
                loss.backward()
                xm.optimizer_step(self.optimizer)
                if self.scheduler is not None:
                    self.scheduler.step()
                if b_idx > 0:
                    self.optimizer.zero_grad()
            if self.use_tpu:
                reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
                losses.update(reduced_loss.item(), data_loader.batch_size)
            else:
                losses.update(loss.item(), data_loader.batch_size)

            if not self.use_tpu:
                tk0.set_postfix(loss=losses.avg)
            else:
                if b_idx % print_idx == 0 or b_idx == len(data_loader):
                    xm.master_print(
                        f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                    )
        if not self.use_tpu:
            tk0.close()
        return losses.avg
Example #12
def _test_scalar():
    def reduce_add(vlist):
        return sum(vlist)

    svalue = 1.25
    rvalue = xm.mesh_reduce('test_mp_mesh_reduce._test_scalar', svalue,
                            reduce_add)
    assert rvalue == svalue * xm.xrt_world_size()
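Several of the training and evaluation loops in these examples pass an undefined `reduce_fn` to `xm.mesh_reduce`; a minimal sketch, assuming it is simply meant to average the values collected from all cores:

def reduce_fn(values):
    # `values` is the list of per-core values gathered by xm.mesh_reduce;
    # return their mean so every core logs the same number.
    return sum(values) / len(values)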
Example #13
def _tpu_gather(tensor, name="tensor"):
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(_tpu_gather(t, name=f"{name}_{i}") for i, t in enumerate(tensor))
    elif isinstance(tensor, dict):
        return type(tensor)({k: _tpu_gather(v, name=f"{name}_{k}") for k, v in tensor.items()})
    elif not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors.")
    return xm.mesh_reduce(name, tensor, torch.cat)
Example #14
def _test_tensor():
    def reduce_add(vlist):
        return torch.stack(vlist).sum(dim=0)

    tvalue = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
    rvalue = xm.mesh_reduce('test_mp_mesh_reduce._test_tensor', tvalue,
                            reduce_add)
    assert rvalue.allclose(tvalue * xm.xrt_world_size())
Example #15
def nested_xla_mesh_reduce(tensors, name):
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        if isinstance(tensors, (list, tuple)):
            return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
        return xm.mesh_reduce(name, tensors, torch.cat)
    else:
        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
Example #16
 def early_stopping_should_stop(self, pl_module):
     stop = torch.tensor(int(self.trainer.should_stop),
                         device=pl_module.device,
                         dtype=torch.int32)
     stop = xm.mesh_reduce("stop_signal", stop, sum)
     torch_xla.core.xla_model.rendezvous(
         "pl.EarlyStoppingCallback.stop_distributed_training_check")
     should_stop = int(stop.item()) == self.trainer.world_size
     return should_stop
Example #17
 def sync_metrics(self, metrics: Dict) -> Dict:
     """Syncs ``metrics`` over ``world_size`` in the distributed mode."""
     metrics = {
         k: xm.mesh_reduce(k,
                           v.item() if isinstance(v, torch.Tensor) else v,
                           np.mean)
         for k, v in metrics.items()
     }
     return metrics
Example #18
def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
    if isinstance(tensor, (list, tuple)):
        return honor_type(tensor, (_tpu_broadcast(t, name=f"{name}_{i}")
                                   for i, t in enumerate(tensor)))
    elif isinstance(tensor, Mapping):
        return type(tensor)({
            k: _tpu_broadcast(v, name=f"{name}_{k}")
            for k, v in tensor.items()
        })
    return xm.mesh_reduce(name, tensor, lambda x: x[src])
Example #19
    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy
Example #20
    def _stop_distributed_training(self, trainer, pl_module):

        # in ddp make sure all processes stop when one is flagged
        if trainer.use_ddp or trainer.use_ddp2:
            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
            dist.all_reduce(stop, op=dist.reduce_op.SUM)
            dist.barrier()
            trainer.should_stop = stop == trainer.world_size

        if trainer.use_tpu:
            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
            # torch.cat fails on 0-dim tensors; sum the per-core stop flags instead
            stop = xm.mesh_reduce("stop_signal", stop, sum)
            torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
            trainer.should_stop = int(stop.item()) == trainer.world_size
Example #21
 def test_loop_fn(loader, epoch):
     total_samples, correct = 0, 0
     model.eval()
     for step, (data, target) in enumerate(loader):
         output = model(data)
         pred = output.max(1, keepdim=True)[1]
         correct += pred.eq(target.view_as(pred)).sum()
         total_samples += data.size()[0]
         if step % FLAGS.log_steps == 0:
             xm.add_step_closure(test_utils.print_test_update,
                                 args=(device, None, epoch, step))
     accuracy = 100.0 * correct.item() / total_samples
     accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
     return accuracy
Example #22
def _tpu_gather(tensor, name="gather tensor"):
    if isinstance(tensor, (list, tuple)):
        return honor_type(tensor, (_tpu_gather(t, name=f"{name}_{i}")
                                   for i, t in enumerate(tensor)))
    elif isinstance(tensor, Mapping):
        return type(tensor)(
            {k: _tpu_gather(v, name=f"{name}_{k}")
             for k, v in tensor.items()})
    elif not isinstance(tensor, torch.Tensor):
        raise TypeError(
            f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
        )
    if tensor.ndim == 0:
        tensor = tensor.clone()[None]
    return xm.mesh_reduce(name, tensor, torch.cat)
Example #23
    def _stop_distributed_training(self, trainer, pl_module):
        if trainer.use_ddp or trainer.use_ddp2:
            stop = torch.tensor(int(trainer.should_stop),
                                device=pl_module.device)
            dist.all_reduce(stop, op=dist.reduce_op.SUM)
            dist.barrier()
            trainer.should_stop = stop == trainer.world_size

        if trainer.use_tpu:
            stop = torch.tensor(int(trainer.should_stop),
                                device=pl_module.device,
                                dtype=torch.int32)
            stop = xm.mesh_reduce("stop_signal", stop,
                                  lambda xs: torch.stack(xs).sum())
            torch_xla.core.xla_model.rendezvous(
                "pl.EarlyStoppingCallback.stop_distributed_training_check")
            trainer.should_stop = int(stop.item()) == trainer.world_size
Example #24
    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if not isinstance(output, torch.Tensor):
            output = torch.tensor(output, device=self.device)

        _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
        _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
        if _invalid_reduce_op or _invalid_reduce_op_str:
            raise MisconfigurationException(
                "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation."
            )

        output = xm.mesh_reduce('reduce', output, sum)

        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
            output = output / self.world_size

        return output
Example #25
 def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
     """Syncs ``metrics`` over ``world_size`` in the distributed mode."""
     if self.state.distributed_type in [
         DistributedType.MULTI_CPU,
         DistributedType.MULTI_GPU,
     ]:
         metrics = {
             k: mean_reduce(
                 torch.tensor(v, device=self.device),
                 world_size=self.state.num_processes,
             )
             for k, v in metrics.items()
         }
     elif self.state.distributed_type == DistributedType.TPU:
         metrics = {
             k: xm.mesh_reduce(
                 k, v.item() if isinstance(v, torch.Tensor) else v, np.mean
             )
             for k, v in metrics.items()
         }
     return metrics
Example #26
def broadcast_object_list(object_list, from_process: int = 0):
    """
    Broadcast a list of picklable objects from one process to the others.

    Args:
        object_list (list of picklable objects):
            The list of objects to broadcast. This list will be modified inplace.
        from_process (:obj:`int`, `optional`, defaults to 0):
            The process from which to send the data.

    Returns:
        The same list containing the objects from process 0.
    """
    if AcceleratorState().distributed_type == DistributedType.TPU:
        for i, obj in enumerate(object_list):
            object_list[i] = xm.mesh_reduce(
                "accelerate.utils.broadcast_object_list", obj,
                lambda x: x[from_process])
    elif AcceleratorState().distributed_type == DistributedType.MULTI_GPU:
        torch.distributed.broadcast_object_list(object_list, src=from_process)
    elif AcceleratorState().distributed_type == DistributedType.MULTI_CPU:
        torch.distributed.broadcast_object_list(object_list, src=from_process)
    return object_list
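A minimal usage sketch for the function above, assuming an Accelerate-style launch where every process calls it with a same-length list (the config contents here are illustrative):

from accelerate.state import AcceleratorState

state = AcceleratorState()
# Process 0 holds the real object; the other processes pass a placeholder.
objects = [{"lr": 3e-5, "epochs": 2}] if state.process_index == 0 else [None]
objects = broadcast_object_list(objects, from_process=0)
# After the call, every process sees the object from process 0.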
Example #27
    def _prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.
        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(
                dataloader,
                [self.args.device]).per_device_loader(self.args.device)

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(
                inputs.get(k) is not None
                for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                if preds is None:
                    preds = logits.detach()
                else:
                    preds = torch.cat((preds, logits.detach()), dim=0)
                if inputs.get("labels") is not None:
                    if label_ids is None:
                        label_ids = inputs["labels"].detach()
                    else:
                        label_ids = torch.cat(
                            (label_ids, inputs["labels"].detach()), dim=0)

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(
                    preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(
                    label_ids,
                    num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids,
                                           torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)
Example #28
def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
    """Evaluate the model"""
    if xm.is_master_ordinal():
        # Only master writes to Tensorboard
        tb_writer = SummaryWriter(args.tensorboard_logdir)

    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
        eval_sampler = get_sampler(eval_dataset)

        if not os.path.exists(eval_output_dir):
            os.makedirs(eval_output_dir)

        dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, shuffle=False)
        eval_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(dataloader) * args.eval_batch_size)
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating", disable=disable_logging):
            model.eval()

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    # XLM, DistilBERT and RoBERTa don't use segment_ids
                    inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None
                outputs = model(**inputs)
                batch_eval_loss, logits = outputs[:2]

                eval_loss += batch_eval_loss
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
        preds = xm.mesh_reduce("eval_preds", preds, np.concatenate)
        out_label_ids = xm.mesh_reduce("eval_out_label_ids", out_label_ids, np.concatenate)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)
        results["eval_loss"] = eval_loss.item()

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        if xm.is_master_ordinal():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results {} *****".format(prefix))
                for key in sorted(results.keys()):
                    logger.info("  %s = %s", key, str(results[key]))
                    writer.write("%s = %s\n" % (key, str(results[key])))
                    tb_writer.add_scalar(f"{eval_task}/{key}", results[key])

    if args.metrics_debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    if xm.is_master_ordinal():
        tb_writer.close()

    return results
Example #29
def train_fn(df):
    size = 1; torch.manual_seed(42)

    df = shuffle(df)
    split = np.int32(SPLIT*len(df))
    val_df, train_df = df[split:], df[:split]

    val_df = val_df.reset_index(drop=True)
    val_set = QuoraDataset(val_df, tokenizer)
    val_sampler = DistributedSampler(val_set, num_replicas=8,
                                     rank=xm.get_ordinal(), shuffle=True)

    train_df = train_df.reset_index(drop=True)
    train_set = QuoraDataset(train_df, tokenizer)
    train_sampler = DistributedSampler(train_set, num_replicas=8,
                                       rank=xm.get_ordinal(), shuffle=True)
    
    val_loader = DataLoader(val_set, VAL_BATCH_SIZE,
                            sampler=val_sampler, num_workers=0, drop_last=True)

    train_loader = DataLoader(train_set, BATCH_SIZE,
                              sampler=train_sampler, num_workers=0, drop_last=True)

    device = xm.xla_device()
    network = Roberta().to(device)
    optimizer = Adam([{'params': network.roberta.parameters(), 'lr': LR[0]*size},
                      {'params': network.dense_output.parameters(), 'lr': LR[1]*size}])

    val_losses, val_f1s = [], []
    train_losses, train_f1s = [], []
    
    start = time.time()
    xm.master_print("STARTING TRAINING ...\n")

    for epoch in range(EPOCHS):

        batch = 1
        network.train()
        fonts = (fg(48), attr('reset'))
        xm.master_print(("EPOCH %s" + str(epoch+1) + "%s") % fonts)

        val_parallel = pl.ParallelLoader(val_loader, [device]).per_device_loader(device)
        train_parallel = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        
        for train_batch in train_parallel:
            train_targ, train_in, train_att = train_batch
            
            network = network.to(device)
            train_in = train_in.to(device)
            train_att = train_att.to(device)
            train_targ = train_targ.to(device)

            train_preds = network.forward(train_in, train_att)
            train_loss = bce(train_preds, train_targ)/len(train_preds)
            train_f1 = f1_score(train_preds, train_targ.squeeze(dim=1))

            optimizer.zero_grad()
            train_loss.backward()
            xm.optimizer_step(optimizer)
            
            end = time.time()
            batch = batch + 1
            is_print = batch % 10 == 1
            f1 = np.round(train_f1.item(), 3)
            if is_print: print_metric(f1, batch, None, start, end, metric="F1", typ="Train")

        val_loss, val_f1, val_points = 0, 0, 0

        network.eval()
        with torch.no_grad():
            for val_batch in val_parallel:
                targ, val_in, val_att = val_batch

                targ = targ.to(device)
                val_in = val_in.to(device)
                val_att = val_att.to(device)
                network = network.to(device)
                pred = network.forward(val_in, val_att)

                val_points += len(targ)
                val_loss += bce(pred, targ).item()
                val_f1 += f1_score(pred, targ.squeeze(dim=1)).item()*len(pred)
        
        end = time.time()
        val_f1 /= val_points
        val_loss /= val_points
        f1 = xm.mesh_reduce('f1', val_f1, lambda x: sum(x)/len(x))
        loss = xm.mesh_reduce('loss', val_loss, lambda x: sum(x)/len(x))
        print_metric(np.round(f1, 3), None, epoch, start, end, metric="F1", typ="Val")
    
        xm.master_print("")
        val_f1s.append(f1); train_f1s.append(train_f1.item())
        val_losses.append(loss); train_losses.append(train_loss.item())

    xm.master_print("ENDING TRAINING ...")
    xm.save(network.state_dict(), MODEL_SAVE_PATH); del network; gc.collect()
    
    metric_lists = [val_losses, train_losses, val_f1s, train_f1s]
    metric_names = ['val_loss_', 'train_loss_', 'val_f1_', 'train_f1_']
    
    for i, metric_list in enumerate(metric_lists):
        for j, metric_value in enumerate(metric_list):
            torch.save(metric_value, metric_names[i] + str(j) + '.pt')
Example #30
    def prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(
                dataloader,
                description,
                prediction_loss_only=prediction_loss_only)

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(
                dataloader,
                [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        for inputs in tqdm(dataloader, desc=description):
            loss, logits, labels = self.prediction_step(
                model, inputs, prediction_loss_only)
            if loss is not None:
                eval_losses.append(loss)
            if logits is not None:
                preds = logits if preds is None else torch.cat(
                    (preds, logits), dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else torch.cat(
                    (label_ids, labels), dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(
                    preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(
                    label_ids,
                    num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids,
                                           torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)