def _gather_loss_distributed(self, loss_sum, num_samples, **kwargs):
    loss_sum = xm.mesh_reduce('gather_loss', loss_sum, np.sum)
    num_samples = xm.mesh_reduce('gather_num_samples', num_samples, np.sum)
    loss = loss_sum / num_samples
    return loss, loss_sum, num_samples
def eval(model, loss, dataloader, device, verbose, epoch, **kwargs):
    print_fn = print
    if device.type == "xla":
        import torch_xla.core.xla_model as xm
        print_fn = xm.master_print
    model.eval()
    total = 0
    correct1 = 0
    correct5 = 0
    total_samples = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total += loss(output, target).item() * data.size(0)
            _, pred = output.topk(5, dim=1)
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            correct1 += correct[:, :1].sum().item()
            correct5 += correct[:, :5].sum().item()
            total_samples += data.size()[0]
    average_loss = 1.0 * total / total_samples
    accuracy1 = 100.0 * correct1 / total_samples
    accuracy5 = 100.0 * correct5 / total_samples
    print_fn(f"Epoch {epoch} evaluation: Average loss: {average_loss:.4f}, "
             f"Top 1 Accuracy: {correct1}/{total_samples} ({accuracy1:.2f}%)")
    if device.type == "xla":
        average_loss = xm.mesh_reduce("test_average_loss", average_loss, np.mean)
        accuracy1 = xm.mesh_reduce("test_accuracy1", accuracy1, np.mean)
        accuracy5 = xm.mesh_reduce("test_accuracy5", accuracy5, np.mean)
    return average_loss, accuracy1, accuracy5
def train(args, train_loader, model, device, optimizer, scheduler, epoch, f):
    total_loss = AverageMeter()
    losses1 = AverageMeter()  # start
    losses2 = AverageMeter()  # end
    accuracies1 = AverageMeter()  # start
    accuracies2 = AverageMeter()  # end
    model.train()
    t = tqdm(train_loader, disable=not xm.is_master_ordinal())
    for step, d in enumerate(t):
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        token_type_ids = d["token_type_ids"].to(device)
        start_position = d["start_position"].to(device)
        end_position = d["end_position"].to(device)

        model.zero_grad()
        logits1, logits2 = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=None,
            head_mask=None,
        )
        y_true = (start_position, end_position)
        loss1, loss2 = loss_fn((logits1, logits2), (start_position, end_position))
        loss = loss1 + loss2
        acc1, n_position1 = get_position_accuracy(logits1, start_position)
        acc2, n_position2 = get_position_accuracy(logits2, end_position)

        total_loss.update(loss.item(), n_position1)
        losses1.update(loss1.item(), n_position1)
        losses2.update(loss2.item(), n_position2)
        accuracies1.update(acc1, n_position1)
        accuracies2.update(acc2, n_position2)

        loss.backward()
        xm.optimizer_step(optimizer)
        scheduler.step()

        print_loss = xm.mesh_reduce("loss_reduce", total_loss.avg, reduce_fn)
        print_acc1 = xm.mesh_reduce("acc1_reduce", accuracies1.avg, reduce_fn)
        print_acc2 = xm.mesh_reduce("acc2_reduce", accuracies2.avg, reduce_fn)
        t.set_description(f"Train E:{epoch+1} - Loss:{print_loss:0.2f} - acc1:{print_acc1:0.2f} - acc2:{print_acc2:0.2f}")

    log_ = f"Epoch : {epoch+1} - train_loss : {total_loss.avg} - \n \
train_loss1 : {losses1.avg} - train_loss2 : {losses2.avg} - \n \
train_acc1 : {accuracies1.avg} - train_acc2 : {accuracies2.avg}"
    f.write(log_ + "\n\n")
    f.flush()
    return total_loss.avg
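# Note (editor): `reduce_fn` is referenced by several snippets in this file but never
# defined here. A minimal sketch, assuming the intent is to average the per-core values
# (xm.mesh_reduce passes the reduce callable a list with one entry per participating replica):
def reduce_fn(values):
    # `values` is a list of scalars, one gathered from each TPU core.
    return sum(values) / len(values)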
def compute(self) -> Tuple[float, float, float]:
    """
    Compute metrics with accumulated statistics

    Returns:
        tuple of metrics: precision, recall, f1 score
    """
    # ddp hotfix, could be done better
    # but metric must handle DDP on its own
    if self._ddp_backend == "xla":
        self.statistics = {
            k: xm.mesh_reduce(k, v, np.sum) for k, v in self.statistics.items()
        }
    elif self._ddp_backend == "ddp":
        for key in self.statistics:
            value: List[int] = all_gather(self.statistics[key])
            value: int = sum(value)
            self.statistics[key] = value

    precision_value, recall_value, f1_value = get_binary_metrics(
        tp=self.statistics["tp"],
        fp=self.statistics["fp"],
        fn=self.statistics["fn"],
        zero_division=self.zero_division,
    )
    return precision_value, recall_value, f1_value
def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    model.train()
    train_loss = []
    for bi, data in enumerate(data_loader):
        ids = data['ids']
        mask = data['mask']
        targets = data['target']

        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(input_ids=ids, attention_mask=mask)
        loss = loss_fn(outputs, targets)
        train_loss.append(loss.item())
        if bi % 500 == 0:
            loss_reduced = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
            xm.master_print(f'bi={bi}, loss={loss_reduced:.4f}')
        loss.backward()
        xm.optimizer_step(optimizer)
        if scheduler is not None:
            scheduler.step()
    return train_loss
def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
    should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device)
    stop = xm.mesh_reduce('stop_signal', should_stop, sum)
    rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
    should_stop = int(stop.item()) == self.world_size
    return should_stop
def evaluate(data_loader, model, device, eval_metric, use_tpu=False):
    losses = AverageMeter()
    final_predictions = []
    final_targets = []
    model.eval()
    with torch.no_grad():
        if use_tpu:
            para_loader = pl.ParallelLoader(data_loader, [device])
            tk0 = tqdm(para_loader.per_device_loader(device), total=len(data_loader),
                       disable=xm.get_ordinal() == 0)
        else:
            tk0 = tqdm(data_loader, total=len(data_loader))
        for b_idx, data in enumerate(tk0):
            for key, value in data.items():
                data[key] = value.to(device)
            _, loss = model(**data)
            if use_tpu:
                reduced_loss = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
                losses.update(reduced_loss.item(), data_loader.batch_size)
            else:
                losses.update(loss.item(), data_loader.batch_size)
            tk0.set_postfix(loss=losses.avg)
    return losses.avg
def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
    # Get the proper rng state
    if rng_type == RNGType.TORCH:
        rng_state = torch.get_rng_state()
    elif rng_type == RNGType.CUDA:
        rng_state = torch.cuda.get_rng_state()
    elif rng_type == RNGType.XLA:
        assert is_tpu_available(), "Can't synchronize XLA seeds on an environment without TPUs."
        rng_state = torch.tensor(xm.get_rng_state())
    elif rng_type == RNGType.GENERATOR:
        assert generator is not None, "Need a generator to synchronize its seed."
        rng_state = generator.get_state()

    # Broadcast the rng state from device 0 to other devices
    state = AcceleratorState()
    if state.distributed_type == DistributedType.TPU:
        rng_state = xm.mesh_reduce("random_seed", rng_state, lambda x: x[0])
    elif state.distributed_type == DistributedType.MULTI_GPU:
        rng_state = rng_state.to(state.device)
        torch.distributed.broadcast(rng_state, 0)
        rng_state = rng_state.cpu()

    # Set the broadcast rng state
    if rng_type == RNGType.TORCH:
        torch.set_rng_state(rng_state)
    elif rng_type == RNGType.CUDA:
        torch.cuda.set_rng_state(rng_state)
    elif rng_type == RNGType.XLA:
        xm.set_rng_state(rng_state.item())
    elif rng_type == RNGType.GENERATOR:
        generator.set_state(rng_state)
def test_loop_fn(model, loader, device, context):
    print("***********************")
    print("ENTERING TEST FUNCTION")
    print("***********************")
    print('Evaluating...')
    total_samples = 0
    correct = 0
    model.eval()
    for step, (data, target) in enumerate(loader):
        if step >= FLAGS.test_max_step:
            break
        output = model(data)
        pred = output.max(1, keepdim=True)[1]
        if FLAGS.mp:
            correct += pred.eq(target.view_as(pred)).sum()
        else:
            correct += pred.eq(target.view_as(pred)).sum().item()
        total_samples += data.size()[0]

    if FLAGS.mp:
        this_accuracy = 100.0 * correct.item() / total_samples
        print("CALLING: mesh_reduce('test_accuracy')")
        this_accuracy = xm.mesh_reduce('test_accuracy', this_accuracy, np.mean)
        print("BACK FROM: mesh_reduce('test_accuracy')")
    else:
        this_accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, this_accuracy)

    print("***********************")
    print("LEAVING TEST FUNCTION")
    print("***********************")
    return this_accuracy
def evaluate(self, data_loader, return_predictions=False):
    losses = AverageMeter()
    print_idx = int(len(data_loader) * self.tpu_print / 100)
    self.model.eval()
    final_predictions = []
    with torch.no_grad():
        if self.use_tpu:
            para_loader = pl.ParallelLoader(data_loader, [self.device])
            tk0 = para_loader.per_device_loader(self.device)
        else:
            tk0 = tqdm(data_loader, total=len(data_loader))
        for b_idx, data in enumerate(tk0):
            for key, value in data.items():
                data[key] = value.to(self.device)
            batch_preds, loss = self.model(**data)
            if return_predictions:
                final_predictions.append(batch_preds)
            if self.use_tpu:
                reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
                losses.update(reduced_loss.item(), data_loader.batch_size)
            else:
                if self.use_mean_loss:
                    loss = loss.mean()
                losses.update(loss.item(), data_loader.batch_size)
            if not self.use_tpu:
                tk0.set_postfix(loss=losses.avg)
            else:
                if b_idx % print_idx == 0 or b_idx == len(data_loader):
                    xm.master_print(
                        f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                    )
    if not self.use_tpu:
        tk0.close()
    return losses.avg, final_predictions
def train(self, data_loader):
    losses = AverageMeter()
    self.model.train()
    print_idx = int(len(data_loader) * self.tpu_print / 100)
    if self.accumulation_steps > 1:
        self.optimizer.zero_grad()
    if self.use_tpu:
        para_loader = pl.ParallelLoader(data_loader, [self.device])
        tk0 = para_loader.per_device_loader(self.device)
    else:
        tk0 = tqdm(data_loader, total=len(data_loader))
    for b_idx, data in enumerate(tk0):
        if self.accumulation_steps == 1 and b_idx == 0:
            self.optimizer.zero_grad()
        if self.model_fn is None:
            for key, value in data.items():
                data[key] = value.to(self.device)
            _, loss = self.model(**data)
        else:
            loss = self.model_fn(data, self.device, self.model)

        if not self.use_tpu:
            with torch.set_grad_enabled(True):
                if self.fp16:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                if (b_idx + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    if self.scheduler is not None:
                        self.scheduler.step()
                    if b_idx > 0:
                        self.optimizer.zero_grad()
        else:
            loss.backward()
            xm.optimizer_step(self.optimizer)
            if self.scheduler is not None:
                self.scheduler.step()
            if b_idx > 0:
                self.optimizer.zero_grad()

        if self.use_tpu:
            reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
            losses.update(reduced_loss.item(), data_loader.batch_size)
        else:
            losses.update(loss.item(), data_loader.batch_size)
        if not self.use_tpu:
            tk0.set_postfix(loss=losses.avg)
        else:
            if b_idx % print_idx == 0 or b_idx == len(data_loader):
                xm.master_print(
                    f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                )
    if not self.use_tpu:
        tk0.close()
    return losses.avg
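# Note (editor): AverageMeter, used by several of the training/eval loops here, is not
# defined in these snippets. A minimal sketch of the usual PyTorch-style helper, assuming
# update(val, n) accumulates a count-weighted running average exposed as `.avg`:
class AverageMeter:
    """Tracks a running sum and count, exposing the weighted average as .avg."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count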
def _test_scalar():
    def reduce_add(vlist):
        return sum(vlist)

    svalue = 1.25
    rvalue = xm.mesh_reduce('test_mp_mesh_reduce._test_scalar', svalue, reduce_add)
    assert rvalue == svalue * xm.xrt_world_size()
def _tpu_gather(tensor, name="tensor"):
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(_tpu_gather(t, name=f"{name}_{i}") for i, t in enumerate(tensor))
    elif isinstance(tensor, dict):
        return type(tensor)({k: _tpu_gather(v, name=f"{name}_{k}") for k, v in tensor.items()})
    elif not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors.")
    return xm.mesh_reduce(name, tensor, torch.cat)
def _test_tensor():
    def reduce_add(vlist):
        return torch.stack(vlist).sum(dim=0)

    tvalue = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
    rvalue = xm.mesh_reduce('test_mp_mesh_reduce._test_tensor', tvalue, reduce_add)
    assert rvalue.allclose(tvalue * xm.xrt_world_size())
def nested_xla_mesh_reduce(tensors, name):
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        if isinstance(tensors, (list, tuple)):
            return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
        return xm.mesh_reduce(name, tensors, torch.cat)
    else:
        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
def early_stopping_should_stop(self, pl_module):
    stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device, dtype=torch.int32)
    stop = xm.mesh_reduce("stop_signal", stop, sum)
    torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
    should_stop = int(stop.item()) == self.trainer.world_size
    return should_stop
def sync_metrics(self, metrics: Dict) -> Dict:
    """Syncs ``metrics`` over ``world_size`` in the distributed mode."""
    metrics = {
        k: xm.mesh_reduce(k, v.item() if isinstance(v, torch.Tensor) else v, np.mean)
        for k, v in metrics.items()
    }
    return metrics
def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
    if isinstance(tensor, (list, tuple)):
        return honor_type(tensor, (_tpu_broadcast(t, name=f"{name}_{i}") for i, t in enumerate(tensor)))
    elif isinstance(tensor, Mapping):
        return type(tensor)({k: _tpu_broadcast(v, name=f"{name}_{k}") for k, v in tensor.items()})
    return xm.mesh_reduce(name, tensor, lambda x: x[src])
def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
        output = model(data)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum()
        total_samples += data.size()[0]

    accuracy = 100.0 * correct.item() / total_samples
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy
def _stop_distributed_training(self, trainer, pl_module):
    # in ddp make sure all processes stop when one is flagged
    if trainer.use_ddp or trainer.use_ddp2:
        stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
        dist.all_reduce(stop, op=dist.reduce_op.SUM)
        dist.barrier()
        trainer.should_stop = stop == trainer.world_size

    if trainer.use_tpu:
        stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
        # sum the per-core flags; torch.cat would fail on these 0-dim tensors
        stop = xm.mesh_reduce("stop_signal", stop, sum)
        torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
        trainer.should_stop = int(stop.item()) == trainer.world_size
def test_loop_fn(loader, epoch):
    total_samples, correct = 0, 0
    model.eval()
    for step, (data, target) in enumerate(loader):
        output = model(data)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum()
        total_samples += data.size()[0]
        if step % FLAGS.log_steps == 0:
            xm.add_step_closure(test_utils.print_test_update, args=(device, None, epoch, step))

    accuracy = 100.0 * correct.item() / total_samples
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy
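# Note (editor): loop functions like the ones above run once per TPU core. A minimal
# launcher sketch, assuming torch_xla's multiprocessing spawner; `_mp_fn` and the
# 8-core count are illustrative, not from the original source.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    # One process per TPU core; xm.xla_device() returns this process's device.
    device = xm.xla_device()
    # ... build the model and data loaders here, then call the train/test loop functions ...


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(), nprocs=8)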
def _tpu_gather(tensor, name="gather tensor"):
    if isinstance(tensor, (list, tuple)):
        return honor_type(tensor, (_tpu_gather(t, name=f"{name}_{i}") for i, t in enumerate(tensor)))
    elif isinstance(tensor, Mapping):
        return type(tensor)({k: _tpu_gather(v, name=f"{name}_{k}") for k, v in tensor.items()})
    elif not isinstance(tensor, torch.Tensor):
        raise TypeError(
            f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
        )
    if tensor.ndim == 0:
        tensor = tensor.clone()[None]
    return xm.mesh_reduce(name, tensor, torch.cat)
def _stop_distributed_training(self, trainer, pl_module):
    if trainer.use_ddp or trainer.use_ddp2:
        stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
        dist.all_reduce(stop, op=dist.reduce_op.SUM)
        dist.barrier()
        trainer.should_stop = stop == trainer.world_size

    if trainer.use_tpu:
        stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
        stop = xm.mesh_reduce("stop_signal", stop, lambda xs: torch.stack(xs).sum())
        torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
        trainer.should_stop = int(stop.item()) == trainer.world_size
def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
    if not isinstance(output, torch.Tensor):
        output = torch.tensor(output, device=self.device)

    _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
    _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
    if _invalid_reduce_op or _invalid_reduce_op_str:
        raise MisconfigurationException(
            "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation."
        )

    output = xm.mesh_reduce('reduce', output, sum)

    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
        output = output / self.world_size

    return output
def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
    """Syncs ``metrics`` over ``world_size`` in the distributed mode."""
    if self.state.distributed_type in [
        DistributedType.MULTI_CPU,
        DistributedType.MULTI_GPU,
    ]:
        metrics = {
            k: mean_reduce(
                torch.tensor(v, device=self.device),
                world_size=self.state.num_processes,
            )
            for k, v in metrics.items()
        }
    elif self.state.distributed_type == DistributedType.TPU:
        metrics = {
            k: xm.mesh_reduce(
                k, v.item() if isinstance(v, torch.Tensor) else v, np.mean
            )
            for k, v in metrics.items()
        }
    return metrics
def broadcast_object_list(object_list, from_process: int = 0):
    """
    Broadcast a list of picklable objects from one process to the others.

    Args:
        object_list (list of picklable objects):
            The list of objects to broadcast. This list will be modified inplace.
        from_process (:obj:`int`, `optional`, defaults to 0):
            The process from which to send the data.

    Returns:
        The same list containing the objects from process 0.
    """
    if AcceleratorState().distributed_type == DistributedType.TPU:
        for i, obj in enumerate(object_list):
            object_list[i] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process])
    elif AcceleratorState().distributed_type == DistributedType.MULTI_GPU:
        torch.distributed.broadcast_object_list(object_list, src=from_process)
    elif AcceleratorState().distributed_type == DistributedType.MULTI_CPU:
        torch.distributed.broadcast_object_list(object_list, src=from_process)
    return object_list
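# Note (editor): a minimal usage sketch for broadcast_object_list above; the payload below
# is illustrative, not from the original source. Every process calls it with a list of the
# same length, and afterwards all processes hold the objects from `from_process`.
objects = [{"epoch": 3, "best_score": 0.91}, "checkpoint-dir-name"]
objects = broadcast_object_list(objects, from_process=0)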
def _prediction_loop(
    self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
) -> PredictionOutput:
    """
    Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

    Works both with or without labels.
    """
    prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

    model = self.model
    # multi-gpu eval
    if self.args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    else:
        model = self.model
    # Note: in torch.distributed mode, there's no point in wrapping the model
    # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

    batch_size = dataloader.batch_size
    logger.info("***** Running %s *****", description)
    logger.info("  Num examples = %d", self.num_examples(dataloader))
    logger.info("  Batch size = %d", batch_size)
    eval_losses: List[float] = []
    preds: torch.Tensor = None
    label_ids: torch.Tensor = None
    model.eval()

    if is_torch_tpu_available():
        dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

    for inputs in tqdm(dataloader, desc=description):
        has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                step_eval_loss, logits = outputs[:2]
                eval_losses += [step_eval_loss.mean().item()]
            else:
                logits = outputs[0]

        if not prediction_loss_only:
            if preds is None:
                preds = logits.detach()
            else:
                preds = torch.cat((preds, logits.detach()), dim=0)
            if inputs.get("labels") is not None:
                if label_ids is None:
                    label_ids = inputs["labels"].detach()
                else:
                    label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0)

    if self.args.local_rank != -1:
        # In distributed mode, concatenate all results from all nodes:
        if preds is not None:
            preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
        if label_ids is not None:
            label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
    elif is_torch_tpu_available():
        # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
        if preds is not None:
            preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
        if label_ids is not None:
            label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

    # Finally, turn the aggregated tensors into numpy arrays.
    if preds is not None:
        preds = preds.cpu().numpy()
    if label_ids is not None:
        label_ids = label_ids.cpu().numpy()

    if self.compute_metrics is not None and preds is not None and label_ids is not None:
        metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
    else:
        metrics = {}
    if len(eval_losses) > 0:
        metrics["eval_loss"] = np.mean(eval_losses)

    # Prefix all keys with eval_
    for key in list(metrics.keys()):
        if not key.startswith("eval_"):
            metrics[f"eval_{key}"] = metrics.pop(key)

    return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
    """Evaluate the model"""
    if xm.is_master_ordinal():
        # Only master writes to Tensorboard
        tb_writer = SummaryWriter(args.tensorboard_logdir)

    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
        eval_sampler = get_sampler(eval_dataset)

        if not os.path.exists(eval_output_dir):
            os.makedirs(eval_output_dir)

        dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, shuffle=False)
        eval_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(dataloader) * args.eval_batch_size)
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating", disable=disable_logging):
            model.eval()
            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    # XLM, DistilBERT and RoBERTa don't use segment_ids
                    inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None
                outputs = model(**inputs)
                batch_eval_loss, logits = outputs[:2]

                eval_loss += batch_eval_loss
                nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
        preds = xm.mesh_reduce("eval_preds", preds, np.concatenate)
        out_label_ids = xm.mesh_reduce("eval_out_label_ids", out_label_ids, np.concatenate)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)
        results["eval_loss"] = eval_loss.item()

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        if xm.is_master_ordinal():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results {} *****".format(prefix))
                for key in sorted(results.keys()):
                    logger.info("  %s = %s", key, str(results[key]))
                    writer.write("%s = %s\n" % (key, str(results[key])))
                    tb_writer.add_scalar(f"{eval_task}/{key}", results[key])

    if args.metrics_debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    if xm.is_master_ordinal():
        tb_writer.close()

    return results
def train_fn(df):
    size = 1
    torch.manual_seed(42)
    df = shuffle(df)
    split = np.int32(SPLIT * len(df))
    val_df, train_df = df[split:], df[:split]

    val_df = val_df.reset_index(drop=True)
    val_set = QuoraDataset(val_df, tokenizer)
    val_sampler = DistributedSampler(val_set, num_replicas=8, rank=xm.get_ordinal(), shuffle=True)

    train_df = train_df.reset_index(drop=True)
    train_set = QuoraDataset(train_df, tokenizer)
    train_sampler = DistributedSampler(train_set, num_replicas=8, rank=xm.get_ordinal(), shuffle=True)

    val_loader = DataLoader(val_set, VAL_BATCH_SIZE, sampler=val_sampler, num_workers=0, drop_last=True)
    train_loader = DataLoader(train_set, BATCH_SIZE, sampler=train_sampler, num_workers=0, drop_last=True)

    device = xm.xla_device()
    network = Roberta().to(device)
    optimizer = Adam([{'params': network.roberta.parameters(), 'lr': LR[0] * size},
                      {'params': network.dense_output.parameters(), 'lr': LR[1] * size}])

    val_losses, val_f1s = [], []
    train_losses, train_f1s = [], []

    start = time.time()
    xm.master_print("STARTING TRAINING ...\n")
    for epoch in range(EPOCHS):
        batch = 1
        network.train()
        fonts = (fg(48), attr('reset'))
        xm.master_print(("EPOCH %s" + str(epoch + 1) + "%s") % fonts)

        val_parallel = pl.ParallelLoader(val_loader, [device]).per_device_loader(device)
        train_parallel = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)

        for train_batch in train_parallel:
            train_targ, train_in, train_att = train_batch
            network = network.to(device)
            train_in = train_in.to(device)
            train_att = train_att.to(device)
            train_targ = train_targ.to(device)

            train_preds = network.forward(train_in, train_att)
            train_loss = bce(train_preds, train_targ) / len(train_preds)
            train_f1 = f1_score(train_preds, train_targ.squeeze(dim=1))

            optimizer.zero_grad()
            train_loss.backward()
            xm.optimizer_step(optimizer)

            end = time.time()
            batch = batch + 1
            is_print = batch % 10 == 1
            f1 = np.round(train_f1.item(), 3)
            if is_print:
                print_metric(f1, batch, None, start, end, metric="F1", typ="Train")

        val_loss, val_f1, val_points = 0, 0, 0
        network.eval()
        with torch.no_grad():
            for val_batch in val_parallel:
                targ, val_in, val_att = val_batch
                targ = targ.to(device)
                val_in = val_in.to(device)
                val_att = val_att.to(device)
                network = network.to(device)

                pred = network.forward(val_in, val_att)
                val_points += len(targ)
                val_loss += bce(pred, targ).item()
                val_f1 += f1_score(pred, targ.squeeze(dim=1)).item() * len(pred)

            end = time.time()
            val_f1 /= val_points
            val_loss /= val_points
            f1 = xm.mesh_reduce('f1', val_f1, lambda x: sum(x) / len(x))
            loss = xm.mesh_reduce('loss', val_loss, lambda x: sum(x) / len(x))
            print_metric(np.round(f1, 3), None, epoch, start, end, metric="F1", typ="Val")
            xm.master_print("")

        val_f1s.append(f1)
        train_f1s.append(train_f1.item())
        val_losses.append(loss)
        train_losses.append(train_loss.item())

    xm.master_print("ENDING TRAINING ...")
    xm.save(network.state_dict(), MODEL_SAVE_PATH)
    del network
    gc.collect()

    metric_lists = [val_losses, train_losses, val_f1s, train_f1s]
    metric_names = ['val_loss_', 'train_loss_', 'val_f1_', 'train_f1_']
    for i, metric_list in enumerate(metric_lists):
        for j, metric_value in enumerate(metric_list):
            torch.save(metric_value, metric_names[i] + str(j) + '.pt')
def prediction_loop(
    self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
) -> PredictionOutput:
    """
    Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

    Works both with or without labels.
    """
    if hasattr(self, "_prediction_loop"):
        warnings.warn(
            "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
            FutureWarning,
        )
        return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

    prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

    model = self.model
    # multi-gpu eval
    if self.args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    else:
        model = self.model
    # Note: in torch.distributed mode, there's no point in wrapping the model
    # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

    batch_size = dataloader.batch_size
    logger.info("***** Running %s *****", description)
    logger.info("  Num examples = %d", self.num_examples(dataloader))
    logger.info("  Batch size = %d", batch_size)
    eval_losses: List[float] = []
    preds: torch.Tensor = None
    label_ids: torch.Tensor = None
    model.eval()

    if is_torch_tpu_available():
        dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

    if self.args.past_index >= 0:
        self._past = None

    for inputs in tqdm(dataloader, desc=description):
        loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
        if loss is not None:
            eval_losses.append(loss)
        if logits is not None:
            preds = logits if preds is None else torch.cat((preds, logits), dim=0)
        if labels is not None:
            label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)

    if self.args.past_index and hasattr(self, "_past"):
        # Clean the state at the end of the evaluation loop
        delattr(self, "_past")

    if self.args.local_rank != -1:
        # In distributed mode, concatenate all results from all nodes:
        if preds is not None:
            preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
        if label_ids is not None:
            label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
    elif is_torch_tpu_available():
        # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
        if preds is not None:
            preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
        if label_ids is not None:
            label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

    # Finally, turn the aggregated tensors into numpy arrays.
    if preds is not None:
        preds = preds.cpu().numpy()
    if label_ids is not None:
        label_ids = label_ids.cpu().numpy()

    if self.compute_metrics is not None and preds is not None and label_ids is not None:
        metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
    else:
        metrics = {}
    if len(eval_losses) > 0:
        metrics["eval_loss"] = np.mean(eval_losses)

    # Prefix all keys with eval_
    for key in list(metrics.keys()):
        if not key.startswith("eval_"):
            metrics[f"eval_{key}"] = metrics.pop(key)

    return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)