Example #1
    def evaluate(self,
                 eval_dataset=None,
                 eval_examples=None,
                 ignore_keys=None):
        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            output = self.prediction_loop(
                eval_dataloader,
                description="Evaluation",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
            )
        finally:
            self.compute_metrics = compute_metrics

        if self.post_process_function is not None and self.compute_metrics is not None:
            eval_preds = self.post_process_function(eval_examples,
                                                    eval_dataset,
                                                    output.predictions)
            metrics = self.compute_metrics(eval_preds)

            self.log(metrics)
        else:
            metrics = {}

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, metrics)
        return metrics
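Note: the snippets on this page assume the usual PyTorch/XLA module aliases. A minimal import block that makes names like xm, pl, met, and xu resolve (the exact subset needed varies per example) would be:

import torch
import torch_xla.core.xla_model as xm               # xm.master_print, xm.xla_device, ...
import torch_xla.distributed.parallel_loader as pl  # pl.ParallelLoader, pl.MpDeviceLoader
import torch_xla.debug.metrics as met               # met.metrics_report
import torch_xla.utils.utils as xu                  # xu.SampleGenerator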
Example #2
def _mp_fn(rank, args):
    print("rank", rank)
    device = xm.xla_device()
    # devices = (
    #   xm.get_xla_supported_devices(
    #       max_devices=args.num_cores) if args.num_cores != 0 else [])
    # with _LOAD_LOCK:
    #     _MODEL.to(device)
    xm.master_print('done loading model')

    criterion = LabelSmoothedLengthGan_CrossEntropyCriterion(
        args, translation_self.tgt_dict)

    params = list(filter(lambda p: p.requires_grad, _MODEL.parameters()))
    optimizer = FairseqAdam(args, params)
    lr_scheduler = InverseSquareRootSchedule(args, optimizer)

    for epoch in range(args.num_epochs):
        # train_loop_fn(args, _MODEL, criterion, optimizer, device)
        # valid_log = eval_loop_fn(args, _MODEL, criterion, device)
        para_loader = pl.ParallelLoader(train_dataloader, [device])
        train_loop_fn(para_loader.per_device_loader(device), args, _MODEL,
                      criterion, optimizer, device)
        para_loader = pl.ParallelLoader(valid_dataloader, [device])
        valid_log = eval_loop_fn(para_loader.per_device_loader(device), args,
                                 _MODEL, criterion, device)
        xm.master_print('Finished training epoch {}'.format(epoch))

        xm.master_print(
            "Epoch {}, loss {:.4f}, nll_loss {:.4f}, length_loss {:.4f}, dis_loss {:.4f}"
            .format(epoch, valid_log["loss"], valid_log["nll_loss"],
                    valid_log["length_loss"], valid_log["dis_loss"]))
        lr_scheduler.step(epoch)
        if args.checkpoint_path:
            xm.save(_MODEL.state_dict(), args.checkpoint_path)
Example #3
    def evaluate(self, data_loader, return_predictions=False):
        losses = AverageMeter()
        # max(1, ...) guards against a zero print interval.
        print_idx = max(1, int(len(data_loader) * self.tpu_print / 100))
        self.model.eval()
        final_predictions = []
        with torch.no_grad():
            if self.use_tpu:
                para_loader = pl.ParallelLoader(data_loader, [self.device])
                tk0 = para_loader.per_device_loader(self.device)
            else:
                tk0 = tqdm(data_loader, total=len(data_loader))
            for b_idx, data in enumerate(tk0):
                for key, value in data.items():
                    data[key] = value.to(self.device)
                if self.fp16:
                    with amp.autocast():
                        batch_preds, loss = self.model(**data)
                else:
                    batch_preds, loss = self.model(**data)
                if return_predictions:
                    final_predictions.append(batch_preds)
                if self.use_tpu:
                    reduced_loss = xm.mesh_reduce("loss_reduce", loss,
                                                  reduce_fn)
                    losses.update(reduced_loss.item(), data_loader.batch_size)
                else:
                    if self.use_mean_loss:
                        loss = loss.mean()
                    losses.update(loss.item(), data_loader.batch_size)
                if not self.use_tpu:
                    tk0.set_postfix(loss=losses.avg)
                else:
                    if b_idx % print_idx == 0 or b_idx == len(data_loader):
                        xm.master_print(
                            f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                        )
            if not self.use_tpu:
                tk0.close()
        return losses.avg, final_predictions
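The reduce_fn passed to xm.mesh_reduce above is defined outside this excerpt. xm.mesh_reduce hands the reduce function the list of per-core values, so a plausible mean reduction (a sketch, not the original helper) is:

def reduce_fn(values):
    # Average the per-core loss values gathered by mesh_reduce.
    return sum(values) / len(values)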
Example #4
    def __init__(self, device, config, steps):
        self.config = config
        self.epoch = 0
        self.steps = steps
        
        self.base_dir = './checkpoints'
        self.log_path = f'{self.base_dir}/log.txt'
        self.best_summary_loss = 10**5

        # get pretrained models
        self.model = Customized_ENSModel(global_config.EfficientNet_Level)
        xm.master_print(">>> Model loaded!")
        self.device = device
        # self.model = self.model.to(device)

        if global_config.LOSS_FN_LabelSmoothing:
            self.criterion = LabelSmoothing()
        else: 
            class_weights = torch.FloatTensor(global_config.CLASS_WEIGHTS)
            self.criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
            print(f">>> Class Weights: {global_config.CLASS_WEIGHTS}")
        self.log(f'>>> Model is loaded. Main Device is {self.device}')
Example #5
def sync_bn1d_no_channel_test(index):
    torch.manual_seed(1)
    bsz = 32
    length = 64
    t_global = torch.rand((xm.xrt_world_size() * bsz, length))

    # XLA SyncBatchNorm
    device = xm.xla_device()
    t_xla = t_global[bsz * index:bsz * (index + 1), ...].to(device)
    sbn_xla = xf.SyncBatchNorm(length).to(device)
    result = run_step(sbn_xla, t_xla)

    # CPU BatchNorm
    bn_cpu = torch.nn.BatchNorm1d(length)
    expected = run_step(bn_cpu, t_global)

    cpu_result = result.cpu()
    assert cpu_result.allclose(expected, rtol=RTOL, atol=ATOL)
    assert_stats(sbn_xla.cpu(), bn_cpu)

    xm.rendezvous('sync_bn1d_no_channel_test')
    xm.master_print('sync_bn1d_no_channel_test ok')
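run_step, assert_stats, RTOL, and ATOL come from elsewhere in the test file. Hypothetical stand-ins consistent with how they are used above (not the originals):

RTOL, ATOL = 1e-4, 1e-4

def run_step(bn, t):
    # One forward pass; as a side effect the layer updates its running stats.
    return bn(t)

def assert_stats(sbn, bn):
    # A synced batch norm over sharded data should track the same global
    # batch statistics as a plain BatchNorm over the full batch.
    assert sbn.running_mean.allclose(bn.running_mean, rtol=RTOL, atol=ATOL)
    assert sbn.running_var.allclose(bn.running_var, rtol=RTOL, atol=ATOL)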
Example #6
def train_loop_fn(train_loader,
                  args,
                  model,
                  criterion,
                  optimizer,
                  device,
                  scheduler=None):
    model.train()
    criterion.train()
    for i, sample in enumerate(train_loader):
        sample = _prepare_sample(sample, device)
        print(sample["target"].shape, sample["target"].device)
        optimizer.zero_grad()
        _, _, logging_output = criterion(model, sample)
        logging = criterion.aggregate_logging_outputs([logging_output])
        loss = logging["loss"]
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        if i % args.log_steps == 0:
            xm.master_print('bi={}, loss={:.4f}'.format(i, loss.item()))
            xm.master_print('MEM: {}'.format(psutil.virtual_memory()))
    print('End training: {}'.format(device))
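_prepare_sample is a project helper that moves a fairseq sample onto the XLA device; a minimal hypothetical version that handles nested containers:

def _prepare_sample(sample, device):
    # Recursively move tensors to the device; leave other values as-is.
    if torch.is_tensor(sample):
        return sample.to(device)
    if isinstance(sample, dict):
        return {k: _prepare_sample(v, device) for k, v in sample.items()}
    if isinstance(sample, (list, tuple)):
        return type(sample)(_prepare_sample(v, device) for v in sample)
    return sample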
Example #7
def train_model_xla(net,
                    batch_size,
                    lr,
                    num_epochs,
                    log_steps=20,
                    metrics_debug=False):
    torch.manual_seed(1)

    train_loader, test_loader = load_cifar_10_xla(batch_size)

    # Scale learning rate to num cores
    lr = lr * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    net = net.to(device)
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    loss_fn = nn.CrossEntropyLoss()

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device), net, optimizer,
                      loss_fn, batch_size, log_steps)
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(
            para_loader.per_device_loader(device), net)
        if metrics_debug:
            xm.master_print(met.metrics_report(), flush=True)

    return accuracy, data, pred, target
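Per-core entry points like this one are normally launched through torch_xla's multiprocessing API; a minimal launcher sketch (net and the hyperparameter values here are placeholders, not from the original):

import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each spawned process drives one TPU core.
    train_model_xla(net, batch_size=128, lr=0.01, num_epochs=10)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(), nprocs=8)  # e.g. 8 cores on a v3-8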
Example #8
    def train_one_epoch(self, train_loader):

        self.model.train()
        summary_loss = AverageMeter()
        final_scores = RocAucMeter()
        t = time.time()

        for step, (images, targets) in enumerate(train_loader):

            t0 = time.time()
            batch_size = images.shape[0]
            outputs = self.model(images)

            self.optimizer.zero_grad()
            loss = self.criterion(outputs, targets)
            loss.backward()                         # compute and sum gradients on params
            #torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=global_config.CLIP_GRAD_NORM) 
            
            xm.optimizer_step(self.optimizer)
            if self.config.step_scheduler:
                self.scheduler.step()

            try:
                final_scores.update(targets, outputs)
            except Exception:
                # xm.master_print("outputs: ", list(outputs.data.cpu().numpy())[:10])
                pass
            summary_loss.update(loss.detach().item(), batch_size)

            if self.config.verbose:
                if step % self.config.verbose_step == 0:

                    t1 = time.time()
                    effNet_lr = np.format_float_scientific(self.optimizer.param_groups[0]['lr'], unique=False, precision=1)
                    head_lr   = np.format_float_scientific(self.optimizer.param_groups[1]['lr'], unique=False, precision=1)
                    xm.master_print(f":::({str(step).rjust(4, ' ')}/{len(train_loader)}) | Loss: {summary_loss.avg:.4f} | AUC: {final_scores.avg:.5f} | LR: {effNet_lr}/{head_lr} | BTime: {t1-t0 :.2f}s | ETime: {int((t1-t0)*(len(train_loader)-step)//60)}m")

        return summary_loss, final_scores
Example #9
    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and return metrics.
        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).
        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`.
        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self._prediction_loop(eval_dataloader, description="Evaluation")

        self._log(output.metrics)
        logger.info("Evaluate results")
        logger.info(output.metrics)
        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics
Example #10
    def __call__(self, epoch_score, model, model_path):
        if self.mode == "min":
            score = -1.0 * epoch_score
        else:
            score = np.copy(epoch_score)

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(epoch_score, model, model_path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.tpu:
                xm.master_print("EarlyStopping counter: {} out of {}".format(
                    self.counter, self.patience))
            else:
                print("EarlyStopping counter: {} out of {}".format(
                    self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(epoch_score, model, model_path)
            self.counter = 0
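This __call__ reads several attributes of an EarlyStopping-style object; a hypothetical constructor matching exactly the fields used above:

class EarlyStopping:
    def __init__(self, patience=5, mode="max", delta=0.0, tpu=False):
        self.patience = patience
        self.mode = mode          # "min" negates the score so lower is better
        self.delta = delta        # minimum improvement to reset the counter
        self.tpu = tpu            # route messages through xm.master_print
        self.counter = 0
        self.best_score = None
        self.early_stop = False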
Example #11
File: engine.py Project: L4zyy/DETR_TPU
def tpu_evaluate(model, criterion, postprocessors, data_loader, base_ds,
                 device, output_dir):
    model.eval()
    criterion.eval()

    cnt = 0
    total = len(data_loader)
    for samples, targets in data_loader:
        print('test')
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        loss = loss_dict['loss_giou']

        xm.master_print('Number: {}/{}, Loss:{}'.format(
            cnt + 1, total, loss.item()))
        cnt += 1
        xm.master_print(met.metrics_report())
        exit()
Example #12
    def __init__(self, model, device, config):
        if not os.path.exists('node_submissions'):
            os.makedirs('node_submissions')

        self.config = config
        self.epoch = 0
        self.log_path = 'log.txt'

        self.model = model
        self.device = device

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0008}, #default 0.001
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr*xm.xrt_world_size())
        self.scheduler = config.SchedulerClass(self.optimizer, **config.scheduler_params)

        self.criterion = config.criterion
        xm.master_print(f'Fitter prepared. Device is {self.device}')
Example #13
    def train_fn(epoch, train_dataloader, optimizer, criterion, scheduler,
                 device):
        model.train()

        for batch_idx, batch_data in enumerate(train_dataloader):
            optimizer.zero_grad()

            batch_data = any2device(batch_data, device)
            outputs = model(**batch_data)

            y_pred = outputs[OUTPUT_PRED_MODIFICATION_TYPE]
            y_true = batch_data[INPUT_TRUE_MODIFICATION_TYPE]

            loss = criterion(y_pred, y_true)

            if batch_idx % 100 == 0:
                xm.master_print(f"Batch: {batch_idx}, loss: {loss.item()}")

            loss.backward()
            xm.optimizer_step(optimizer)

            if scheduler is not None:
                scheduler.step()
Example #14
def sync_bn3d_test(index):
    torch.manual_seed(1)
    bsz = 16
    features = 32
    d, h, w = 16, 32, 32
    t_global = torch.rand((xm.xrt_world_size() * bsz, features, d, h, w))

    # XLA SyncBatchNorm
    device = xm.xla_device()
    t_xla = t_global[bsz * index:bsz * (index + 1), ...].to(device)
    sbn_xla = xf.SyncBatchNorm(features).to(device)
    result = run_step(sbn_xla, t_xla)

    # CPU BatchNorm
    bn_cpu = torch.nn.BatchNorm3d(features)
    expected = run_step(bn_cpu, t_global)

    cpu_result = result.cpu()
    assert cpu_result.allclose(expected, rtol=RTOL, atol=ATOL)
    assert_stats(sbn_xla.cpu(), bn_cpu)

    xm.rendezvous('sync_bn3d_test')
    xm.master_print('sync_bn3d_test ok')
Example #15
    def valid_fn(epoch, valid_dataloader, criterion, device):
        model.eval()

        pred_scores = []
        true_scores = []

        for batch_idx, batch_data in enumerate(valid_dataloader):
            batch_data = any2device(batch_data, device)
            outputs = model(**batch_data)

            y_pred = outputs[OUTPUT_PRED_MODIFICATION_TYPE]
            y_true = batch_data[INPUT_TRUE_MODIFICATION_TYPE]

            loss = criterion(y_pred, y_true)

            pred_scores.extend(to_numpy(parse_classifier_probas(y_pred)))
            true_scores.extend(to_numpy(y_true))

            xm.master_print(f"Batch: {batch_idx}, loss: {loss.item()}")

        val_wauc = alaska_weighted_auc(xla_all_gather(true_scores, device),
                                       xla_all_gather(pred_scores, device))
        xm.master_print(f"Valid epoch: {epoch}, wAUC: {val_wauc}")
        return val_wauc
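xla_all_gather is project-specific; given how Example #21 gathers numpy arrays across workers, a plausible stand-in built on xm.mesh_reduce would be:

import numpy as np

def xla_all_gather(values, device, tag="all_gather"):
    # Concatenate each core's score list into one global array.
    # (device is accepted only to match the call site above.)
    return xm.mesh_reduce(tag, np.asarray(values), np.concatenate)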
Example #16
    def prepare_task(args, xla_device):
        # Setup task, e.g., translation, language modeling, etc.
        task = tasks.setup_task(args)

        # Load valid dataset (we load training data below, based on the latest checkpoint)
        for valid_sub_split in args.valid_subset.split(','):
            task.load_dataset(valid_sub_split, combine=True, epoch=0)

        # Build models and criteria to print some metadata
        torch.manual_seed(args.seed)
        model, criterion = task.build_model(args), task.build_criterion(args)
        xm.master_print(model)
        xm.master_print('| model {}, criterion {}'.format(
            args.arch, criterion.__class__.__name__))
        xm.master_print('| num. model params: {} (num. trained: {})'.format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad)))
        model = model.to(xla_device)
        trainer = Trainer(args, task, model, criterion, xla_device=xla_device)
        lr = trainer.get_lr()

        # Load the latest checkpoint if one is available and restore the
        # corresponding train iterator
        # we overwrite distributed args here to shard data using torch_xla's
        # distributed training.
        trainer.args.distributed_rank = xm.get_ordinal()
        trainer.args.distributed_world_size = xm.xrt_world_size()
        extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
        trainer.args.distributed_rank = 0
        trainer.args.distributed_world_size = 1
        trainer.meters_to_device(xla_device)
        valid_subsets = args.valid_subset.split(',')
        ordinal = xm.get_ordinal(defval=-1)
        device_str = (
            str(xla_device) if ordinal < 0 else
            '{}/{}'.format(xla_device, ordinal)
        )
        return task, trainer, model, epoch_itr, lr, valid_subsets, device_str
Example #17
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to the model if the model to train has been instantiated from a local path.
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"),
                           map_location=self.args.device))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = (self.args.train_batch_size *
                                      xm.xrt_world_size())
        else:
            total_train_batch_size = (self.args.train_batch_size *
                                      self.args.gradient_accumulation_steps *
                                      (torch.distributed.get_world_size()
                                       if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_device_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(epochs_trained,
                                int(num_train_epochs),
                                desc="Epoch",
                                disable=not self.is_local_master())
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(
                    train_dataloader,
                    [self.args.device]).per_device_loader(self.args.device)
                epoch_iterator = tqdm(parallel_loader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0
                            and self.global_step % self.args.logging_steps
                            == 0) or (self.global_step == 1
                                      and self.args.logging_first_step):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss -
                                        logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >=
                            version.parse("1.4") else scheduler.get_lr()[0])
                        logging_loss = tr_loss

                        self._log(logs)

                        if self.args.evaluate_during_training:
                            self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(
                            self.args.output_dir,
                            f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(),
                                    os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(),
                                    os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(self.global_step, tr_loss / self.global_step)
Example #18
def train_mnist(flags, state_dict):
    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=60000 // flags.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=10000 // flags.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(os.path.join(flags.datadir,
                                                    str(xm.get_ordinal())),
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        test_dataset = datasets.MNIST(os.path.join(flags.datadir,
                                                   str(xm.get_ordinal())),
                                      train=False,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST()
    model.load_state_dict(state_dict)
    model = model.to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, writer),
                                    run_async=flags.async_closures)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct.item() / total_samples
        # accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
Example #19
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, epoch,
                                          writer))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(test_utils.print_test_update,
                                    args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch,
                                                       test_utils.now()))
        accuracy = test_loop_fn(test_device_loader, epoch)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
Example #20
def train(args, train_dataset, model, tokenizer, disable_logging=False):
    """ Train the model """
    if xm.is_master_ordinal():
        # Only master writes to Tensorboard
        tb_writer = SummaryWriter(args.tensorboard_logdir)

    train_sampler = get_sampler(train_dataset)
    dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total,
    )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(dataloader) * args.train_batch_size)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per TPU core = %d", args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        (args.train_batch_size * args.gradient_accumulation_steps * xm.xrt_world_size()),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    loss = None
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=disable_logging)
    set_seed(args.seed)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        # tpu-comment: Get TPU parallel loader which sends data to TPU in background.
        train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", total=len(dataloader), disable=disable_logging)
        for step, batch in enumerate(epoch_iterator):

            # Save model checkpoint.
            if args.save_steps > 0 and global_step % args.save_steps == 0:
                output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                logger.info("Saving model checkpoint to %s", output_dir)

                if xm.is_master_ordinal():
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))

                # Barrier to wait for saving checkpoint.
                xm.rendezvous("mid_training_checkpoint")
                # model.save_pretrained needs to be called by all ordinals
                model.save_pretrained(output_dir)

            model.train()
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            if args.model_type != "distilbert":
                # XLM, DistilBERT and RoBERTa don't use segment_ids
                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                xm.optimizer_step(optimizer)
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics.
                    results = {}
                    if args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer, disable_logging=disable_logging)
                    loss_scalar = loss.item()
                    logger.info(
                        "global_step: {global_step}, lr: {lr:.6f}, loss: {loss:.3f}".format(
                            global_step=global_step, lr=scheduler.get_lr()[0], loss=loss_scalar
                        )
                    )
                    if xm.is_master_ordinal():
                        # tpu-comment: All values must be in CPU and not on TPU device
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar("loss", loss_scalar, global_step)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if xm.is_master_ordinal():
        tb_writer.close()
    return global_step, loss.item()
Example #21
def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
    """Evaluate the model"""
    if xm.is_master_ordinal():
        # Only master writes to Tensorboard
        tb_writer = SummaryWriter(args.tensorboard_logdir)

    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
        eval_sampler = get_sampler(eval_dataset)

        if not os.path.exists(eval_output_dir):
            os.makedirs(eval_output_dir)

        dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, shuffle=False)
        eval_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(dataloader) * args.eval_batch_size)
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating", disable=disable_logging):
            model.eval()

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    # XLM, DistilBERT and RoBERTa don't use segment_ids
                    inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None
                outputs = model(**inputs)
                batch_eval_loss, logits = outputs[:2]

                eval_loss += batch_eval_loss
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
        preds = xm.mesh_reduce("eval_preds", preds, np.concatenate)
        out_label_ids = xm.mesh_reduce("eval_out_label_ids", out_label_ids, np.concatenate)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)
        results["eval_loss"] = eval_loss.item()

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        if xm.is_master_ordinal():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results {} *****".format(prefix))
                for key in sorted(results.keys()):
                    logger.info("  %s = %s", key, str(results[key]))
                    writer.write("%s = %s\n" % (key, str(results[key])))
                    tb_writer.add_scalar(f"{eval_task}/{key}", results[key])

    if args.metrics_debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    if xm.is_master_ordinal():
        tb_writer.close()

    return results
Example #22
    def train(self, data_loader):
        losses = AverageMeter()
        self.model.train()
        print_idx = max(1, int(len(data_loader) * self.tpu_print / 100))
        if self.accumulation_steps > 1:
            self.optimizer.zero_grad()
        if self.use_tpu:
            para_loader = pl.ParallelLoader(data_loader, [self.device])
            tk0 = para_loader.per_device_loader(self.device)
        else:
            tk0 = tqdm(data_loader, total=len(data_loader))

        for b_idx, data in enumerate(tk0):
            if self.accumulation_steps == 1 and b_idx == 0:
                self.optimizer.zero_grad()

            if self.model_fn is None:
                for key, value in data.items():
                    data[key] = value.to(self.device)
                _, loss = self.model(**data)
            else:
                if self.fp16:
                    with amp.autocast():
                        loss = self.model_fn(data, self.device, self.model)
                else:
                    loss = self.model_fn(data, self.device, self.model)

            if not self.use_tpu:
                with torch.set_grad_enabled(True):
                    if self.use_mean_loss:
                        loss = loss.mean()

                    if self.fp16:
                        self.scaler.scale(loss).backward()
                    else:
                        loss.backward()

                    if (b_idx + 1) % self.accumulation_steps == 0:
                        if self.fp16:
                            self.scaler.step(self.optimizer)
                        else:
                            self.optimizer.step()

                        if self.scheduler is not None:
                            self.scheduler.step()

                        if b_idx > 0:
                            self.optimizer.zero_grad()

                    if self.fp16:
                        self.scaler.update()
            else:
                loss.backward()
                xm.optimizer_step(self.optimizer)
                if self.scheduler is not None:
                    self.scheduler.step()
                if b_idx > 0:
                    self.optimizer.zero_grad()
            if self.use_tpu:
                reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
                losses.update(reduced_loss.item(), data_loader.batch_size)
            else:
                losses.update(loss.item(), data_loader.batch_size)

            if not self.use_tpu:
                tk0.set_postfix(loss=losses.avg)
            else:
                if b_idx % print_idx == 0 or b_idx == len(data_loader):
                    xm.master_print(
                        f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                    )
        if not self.use_tpu:
            tk0.close()
        return losses.avg
Example #23
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "require_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an
            # optimizer step
            D.optim.zero_grad()

            for accumulation_index in range(config['num_D_accumulations']):
                z_, y_ = sample()
                D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']],
                                    x[counter], y[counter], train_G=False,
                                    split_D=config['split_D'])
                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)
                D_loss = ((D_loss_real + D_loss_fake) /
                          float(config['num_D_accumulations']))
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                xm.master_print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            xm.optimizer_step(D.optim)

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):

            z_, y_ = sample()
            D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
            G_loss = losses.generator_loss(
                D_fake) / float(config['num_G_accumulations'])
            G_loss.backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in G.
            xm.master_print('using modified ortho reg in G')
            # Don't ortho reg shared, it makes no sense. Really we should
            # blacklist any embeddings for this
            utils.ortho(G, config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
        xm.optimizer_step(G.optim)

        # If we have an EMA, update it, regardless of whether we test with it
        if config['ema']:
            ema.update(state_dict['itr'])

        out = {'G_loss': G_loss,
               'D_loss_real': D_loss_real,
               'D_loss_fake': D_loss_fake}
        # Return G's loss and the components of D's loss.
        return out
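Both the D and G loops above pre-divide each micro-batch loss by the number of accumulations, so the gradients summed by repeated backward() calls match a single full-batch average before the one optimizer step. A self-contained sketch of that pattern, with an illustrative model and shapes:

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
num_accumulations = 4

opt.zero_grad()
for _ in range(num_accumulations):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    # pre-divide so the accumulated grads equal the full-batch average
    loss = torch.nn.functional.mse_loss(model(x), y) / num_accumulations
    loss.backward()  # gradients accumulate in .grad
opt.step()  # xm.optimizer_step(opt) on TPU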
Example #24
def run(index):
    MAX_LEN = 512
    TRAIN_BATCH_SIZE = 16
    EPOCHS = 50
    dfx = pd.read_csv("/home/nizamphoenix/dataset/train.csv").fillna("none")
    df_train, df_valid = model_selection.train_test_split(dfx,
                                                          random_state=42,
                                                          test_size=0.3)
    df_train = df_train.reset_index(drop=True)
    df_valid = df_valid.reset_index(drop=True)

    sample = pd.read_csv("/home/nizamphoenix/dataset/sample_submission.csv")
    target_cols = list(sample.drop("qa_id", axis=1).columns)
    train_targets = df_train[target_cols].values
    valid_targets = df_valid[target_cols].values

    tokenizer = transformers.BertTokenizer.from_pretrained(
        "/home/nizamphoenix/bert-base-uncased/")

    train_dataset = BERTDatasetTraining(qtitle=df_train.question_title.values,
                                        qbody=df_train.question_body.values,
                                        answer=df_train.answer.values,
                                        targets=train_targets,
                                        tokenizer=tokenizer,
                                        max_len=MAX_LEN)

    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=TRAIN_BATCH_SIZE, sampler=train_sampler)

    valid_dataset = BERTDatasetTraining(qtitle=df_valid.question_title.values,
                                        qbody=df_valid.question_body.values,
                                        answer=df_valid.answer.values,
                                        targets=valid_targets,
                                        tokenizer=tokenizer,
                                        max_len=MAX_LEN)

    valid_sampler = torch.utils.data.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
    )
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=8,  #can make changes here
        sampler=valid_sampler)

    device = xm.xla_device()

    lr = 2e-5 * xm.xrt_world_size()  #can make changes here
    num_train_steps = int(
        len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    model = BERTBaseUncased("/home/nizamphoenix/bert-base-uncased/").to(device)
    optimizer = AdamW(model.parameters(), lr=lr,
                      eps=1e-8)  #eps = 1e-8: to prevent any division by zero
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

    for epoch in range(EPOCHS):
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer,
                      device, scheduler)

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        o, t = eval_loop_fn(para_loader.per_device_loader(device), model,
                            device)

        spear = []
        for jj in range(t.shape[1]):
            p1 = list(t[:, jj])
            p2 = list(o[:, jj])
            coef, _ = stats.spearmanr(p1, p2)
            spear.append(np.nan_to_num(coef))  # guard against NaN correlations
        spear = np.mean(spear)
        xm.master_print(f"epoch={epoch},spearman={spear}")
        xm.save(model.state_dict(), "model3.bin")  # change the filename for each run
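run() scales the learning rate by the number of cores because every core processes its own TRAIN_BATCH_SIZE batch, so the effective global batch is world_size times larger; num_train_steps divides by the world size for the same reason. The arithmetic, with a hypothetical dataset size:

world_size = 8                      # xm.xrt_world_size() on a v3-8
base_lr, batch_size, epochs = 2e-5, 16, 50
lr = base_lr * world_size           # 1.6e-4 effective learning rate

num_examples = 4_863                # hypothetical len(train_dataset)
steps_per_epoch = num_examples // (batch_size * world_size)
num_train_steps = steps_per_epoch * epochs
print(lr, num_train_steps)          # 0.00016, 1850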
Example #25
def train_mnist():
    torch.manual_seed(1)

    """
    When using a TPU, the dataset setup looks like this:
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    """
    def get_dataset():
        norm = transforms.Normalize((0.1307,), (0.3081,))
        train_dataset = datasets.MNIST(
            FLAGS['datadir'],
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), norm]))
        test_dataset = datasets.MNIST(
            FLAGS['datadir'],
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), norm]))

        return train_dataset, test_dataset

    # Using the serial executor prevents multiple processes from
    # downloading the same data.
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS['batch_size'],
        sampler=train_sampler,
        num_workers=FLAGS['num_workers'],
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS['batch_size'],
        shuffle=False,
        num_workers=FLAGS['num_workers'],
        drop_last=True)

    # Scale learning rate to world size
    lr = FLAGS['learning_rate'] * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    """
    When using a TPU, the device setup looks like this:
    device = xm.xla_device()
    model = xmp.MpModelWrapper(MNIST()).to(device)
    """
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            # When using a TPU, step the optimizer with xm.optimizer_step(optimizer)
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS['batch_size'])
            if x % FLAGS['log_steps'] == 0:
                print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                    xm.get_ordinal(), x, loss.item(), tracker.rate(),
                    tracker.global_rate(), time.asctime()), flush=True)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        print('[xla:{}] Accuracy={:.2f}%'.format(
            xm.get_ordinal(), accuracy), flush=True)
        return accuracy, data, pred, target

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, FLAGS['num_epochs'] + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))
        if FLAGS['metrics_debug']:
            xm.master_print(met.metrics_report(), flush=True)

    return accuracy, data, pred, target
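train_mnist() relies on module-level SERIAL_EXEC and WRAPPED_MODEL globals and is meant to run once per TPU core. A sketch of how such an entry point is typically launched, assuming an MNIST nn.Module like the one this example expects; the FLAGS values are illustrative:

import torch_xla.distributed.xla_multiprocessing as xmp

FLAGS = {'datadir': '/tmp/mnist', 'batch_size': 128, 'num_workers': 4,
         'learning_rate': 0.01, 'momentum': 0.5, 'num_epochs': 10,
         'log_steps': 20, 'metrics_debug': False}

SERIAL_EXEC = xmp.MpSerialExecutor()
WRAPPED_MODEL = xmp.MpModelWrapper(MNIST())  # assumes the MNIST module used above

def _mp_fn(rank, flags):
    global FLAGS
    FLAGS = flags
    accuracy, data, pred, target = train_mnist()

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')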
Example #26
def train_loop(folds, fold):

    if CFG.device == 'GPU':
        LOGGER.info(f"========== fold: {fold} training ==========")
    elif CFG.device == 'TPU':
        if CFG.nprocs == 1:
            LOGGER.info(f"========== fold: {fold} training ==========")
        elif CFG.nprocs == 8:
            xm.master_print(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds[CFG.target_cols].values

    train_dataset = TrainDataset(train_folds,
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds,
                                 transform=get_transforms(data='valid'))

    if CFG.device == 'GPU':
        train_loader = DataLoader(train_dataset,
                                  batch_size=CFG.batch_size,
                                  shuffle=True,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=CFG.batch_size * 2,
                                  shuffle=False,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=False)

    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=CFG.batch_size,
                                                   sampler=train_sampler,
                                                   drop_last=True,
                                                   num_workers=CFG.num_workers)

        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=CFG.batch_size *
                                                   2,
                                                   sampler=valid_sampler,
                                                   drop_last=False,
                                                   num_workers=CFG.num_workers)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(optimizer):
        if CFG.scheduler == 'ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=CFG.factor,
                                          patience=CFG.patience,
                                          verbose=True,
                                          eps=CFG.eps)
        elif CFG.scheduler == 'CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CFG.T_max,
                                          eta_min=CFG.min_lr,
                                          last_epoch=-1)
        elif CFG.scheduler == 'CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                    T_0=CFG.T_0,
                                                    T_mult=1,
                                                    eta_min=CFG.min_lr,
                                                    last_epoch=-1)
        return scheduler

    # ====================================================
    # model & optimizer
    # ====================================================
    if CFG.device == 'TPU':
        device = xm.xla_device()
    elif CFG.device == 'GPU':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = CustomResNet200D_WLF(CFG.model_name, pretrained=False)
    model.load_state_dict(
        torch.load(CFG.student, map_location=torch.device('cpu'))['model'])
    model.to(device)

    optimizer = Adam(model.parameters(),
                     lr=CFG.lr,
                     weight_decay=CFG.weight_decay,
                     amsgrad=False)
    scheduler = get_scheduler(optimizer)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.BCEWithLogitsLoss()

    best_score = 0.
    best_loss = np.inf

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_loss = train_fn(train_loader, model, criterion, optimizer,
                                    epoch, scheduler, device)
            elif CFG.nprocs == 8:
                para_train_loader = pl.ParallelLoader(train_loader, [device])
                avg_loss = train_fn(
                    para_train_loader.per_device_loader(device), model,
                    criterion, optimizer, epoch, scheduler, device)
        elif CFG.device == 'GPU':
            avg_loss = train_fn(train_loader, model, criterion, optimizer,
                                epoch, scheduler, device)

        # eval
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_val_loss, preds, _ = valid_fn(valid_loader, model,
                                                  criterion, device)
            elif CFG.nprocs == 8:
                para_valid_loader = pl.ParallelLoader(valid_loader, [device])
                avg_val_loss, preds, valid_labels = valid_fn(
                    para_valid_loader.per_device_loader(device), model,
                    criterion, device)
                preds = idist.all_gather(torch.tensor(preds)).to('cpu').numpy()
                valid_labels = idist.all_gather(
                    torch.tensor(valid_labels)).to('cpu').numpy()
        elif CFG.device == 'GPU':
            avg_val_loss, preds, _ = valid_fn(valid_loader, model, criterion,
                                              device)

        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, (CosineAnnealingLR,
                                    CosineAnnealingWarmRestarts)):
            scheduler.step()

        # scoring
        score, scores = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        if CFG.device == 'GPU':
            LOGGER.info(
                f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
            )
            LOGGER.info(
                f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
            )
        elif CFG.device == 'TPU':
            if CFG.nprocs == 1:
                LOGGER.info(
                    f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
                )
                LOGGER.info(
                    f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
                )
            elif CFG.nprocs == 8:
                xm.master_print(
                    f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
                )
                xm.master_print(
                    f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
                )

        if score > best_score:
            best_score = score
            if CFG.device == 'GPU':
                LOGGER.info(
                    f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model'
                )
                torch.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(
                        f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model'
                    )
                elif CFG.nprocs == 8:
                    xm.master_print(
                        f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model'
                    )
                xm.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            if CFG.device == 'GPU':
                LOGGER.info(
                    f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                torch.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(
                        f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model'
                    )
                elif CFG.nprocs == 8:
                    xm.master_print(
                        f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model'
                    )
                xm.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')

        if CFG.nprocs != 8:
            check_point = torch.load(
                OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
            for c in [f'pred_{c}' for c in CFG.target_cols]:
                valid_folds[c] = np.nan
            valid_folds[[f'pred_{c}'
                         for c in CFG.target_cols]] = check_point['preds']

    return valid_folds
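In the nprocs == 8 branch above, each core's valid_fn only sees its shard of the validation set, so predictions and labels are all-gathered (here via pytorch-ignite's idist helpers) before get_score runs on the full arrays. A sketch of that gather step with illustrative shapes:

import torch
import ignite.distributed as idist

local_preds = torch.rand(100, 11)                     # this core's shard
local_labels = torch.randint(0, 2, (100, 11)).float()

preds = idist.all_gather(local_preds).to('cpu').numpy()
valid_labels = idist.all_gather(local_labels).to('cpu').numpy()
# preds/valid_labels now hold every core's rows, ready for get_score()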
Example #27
def valid_fn(valid_loader, model, criterion, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to evaluation mode
    model.eval()
    trues = []
    preds = []
    start = end = time.time()
    for step, (images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        # compute loss
        with torch.no_grad():
            _, _, y_preds = model(images)
        loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)
        # record accuracy
        trues.append(labels.to('cpu').numpy())
        preds.append(y_preds.sigmoid().to('cpu').numpy())
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(valid_loader) - 1):
                print('EVAL: [{0}/{1}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '.format(
                          step,
                          len(valid_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          remain=timeSince(start,
                                           float(step + 1) /
                                           len(valid_loader)),
                      ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(valid_loader) - 1):
                xm.master_print(
                    'EVAL: [{0}/{1}] '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                    'Elapsed {remain:s} '
                    'Loss: {loss.val:.4f}({loss.avg:.4f}) '.format(
                        step,
                        len(valid_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        remain=timeSince(start,
                                         float(step + 1) / len(valid_loader)),
                    ))
    trues = np.concatenate(trues)
    predictions = np.concatenate(preds)
    return losses.avg, predictions, trues
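valid_fn assumes an AverageMeter helper that tracks the latest value and a running average. A minimal sketch of the version these examples typically use:

class AverageMeter:
    """Tracks the current value, running sum, count, and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count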
Example #28
def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler,
             device):
    if CFG.device == 'GPU':
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    grad_norm = 0.0  # updated whenever the optimizer steps; logged below
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        if CFG.device == 'GPU':
            with autocast():
                _, _, y_preds = model(images)
                loss = criterion(y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            scaler.scale(loss).backward()
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                # unscale before clipping so the norm is measured on the
                # true (unscaled) gradients
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), CFG.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                global_step += 1

        elif CFG.device == 'TPU':
            _, _, y_preds = model(images)
            loss = criterion(y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       CFG.max_grad_norm)
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
                global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                print('Epoch: [{0}][{1}/{2}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                      'Grad: {grad_norm:.4f}  '
                      #'LR: {lr:.6f}  '
                      .format(
                       epoch+1, step, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses,
                       remain=timeSince(start, float(step+1)/len(train_loader)),
                       grad_norm=grad_norm,
                       #lr=scheduler.get_lr()[0],
                       ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                xm.master_print('Epoch: [{0}][{1}/{2}] '
                                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                                'Elapsed {remain:s} '
                                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                                'Grad: {grad_norm:.4f}  '
                                #'LR: {lr:.6f}  '
                                .format(
                                epoch+1, step, len(train_loader), batch_time=batch_time,
                                data_time=data_time, loss=losses,
                                remain=timeSince(start, float(step+1)/len(train_loader)),
                                grad_norm=grad_norm,
                                #lr=scheduler.get_lr()[0],
                                ))
    return losses.avg
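The GPU branch above follows the canonical GradScaler order: scale the loss, backward, unscale before clipping, then step and update. A condensed, self-contained sketch of that flow, assuming a CUDA device is available (model, shapes, and max_norm are illustrative):

import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(4, 1).cuda()
opt = torch.optim.Adam(model.parameters())
scaler = GradScaler()

x, y = torch.randn(8, 4, device='cuda'), torch.randn(8, 1, device='cuda')
with autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(opt)                  # gradients back to their true scale
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1000)
scaler.step(opt)                      # skips the step if grads are inf/nan
scaler.update()
opt.zero_grad()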
Example #29
def train_tpu():
    torch.manual_seed(1)

    def get_dataset():
        fold_number = 0
        train_ = pd.read_csv(args.train_fold)
        train = ShopeeDataset(
            train_[train_['fold'] != fold_number].reset_index(drop=True))
        test = ShopeeDataset(
            train_[train_['fold'] == fold_number].reset_index(drop=True),
            transform=args.test_args)  # evaluate on the held-out fold
        return train, test

    # Using the serial executor prevents multiple processes
    # from downloading the same data.
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               sampler=train_sampler,
                                               num_workers=args.num_workers,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              drop_last=True)

    # Scale learning rate to num cores
    learning_rate = 1e-5 * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, label) in enumerate(loader):
            optimizer.zero_grad()
            output = model(image=data,
                           label=label,
                           get_embedding=args.get_embeddings)
            loss = loss_fn(output, label)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(args.batch_size)
            if x % 20 == 0:
                print(
                    '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                    .format(xm.get_ordinal(), x, loss.item(), tracker.rate(),
                            tracker.global_rate(), time.asctime()),
                    flush=True)

    def test_loop_fn(loader):
        model.eval()
        with torch.no_grad():  # no gradients needed for evaluation
            for x, (data, label) in enumerate(loader):
                output = model(image=data,
                               label=label,
                               get_embedding=args.get_embeddings)
                loss = loss_fn(output, label)
                if x % 20 == 0:
                    print('[xla:{}]({}) Loss={:.5f}'.format(
                        xm.get_ordinal(), x, loss.item()),
                        flush=True)

    for epoch in range(1, args.n_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        test_loop_fn(para_loader.per_device_loader(device))
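The epoch loop above rebuilds a pl.ParallelLoader before every epoch; pl.MpDeviceLoader wraps a loader once and can simply be re-iterated. A toy, self-contained sketch of that alternative, to be run inside an xmp.spawn worker:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
dataset = torch.utils.data.TensorDataset(torch.randn(64, 4))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)
mp_loader = pl.MpDeviceLoader(loader, device)

for epoch in range(2):
    for (batch,) in mp_loader:  # batches arrive already on the XLA device
        pass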
Example #30
def train_fn(df):
    size = 1
    torch.manual_seed(42)

    df = shuffle(df)
    split = np.int32(SPLIT*len(df))
    val_df, train_df = df[split:], df[:split]

    val_df = val_df.reset_index(drop=True)
    val_set = QuoraDataset(val_df, tokenizer)
    val_sampler = DistributedSampler(val_set, num_replicas=8,
                                     rank=xm.get_ordinal(), shuffle=True)

    train_df = train_df.reset_index(drop=True)
    train_set = QuoraDataset(train_df, tokenizer)
    train_sampler = DistributedSampler(train_set, num_replicas=8,
                                       rank=xm.get_ordinal(), shuffle=True)
    
    val_loader = DataLoader(val_set, VAL_BATCH_SIZE,
                            sampler=val_sampler, num_workers=0, drop_last=True)

    train_loader = DataLoader(train_set, BATCH_SIZE,
                              sampler=train_sampler, num_workers=0, drop_last=True)

    device = xm.xla_device()
    network = Roberta().to(device)
    optimizer = Adam([{'params': network.roberta.parameters(), 'lr': LR[0]*size},
                      {'params': network.dense_output.parameters(), 'lr': LR[1]*size}])

    val_losses, val_f1s = [], []
    train_losses, train_f1s = [], []
    
    start = time.time()
    xm.master_print("STARTING TRAINING ...\n")

    for epoch in range(EPOCHS):

        batch = 1
        network.train()
        fonts = (fg(48), attr('reset'))
        xm.master_print(("EPOCH %s" + str(epoch+1) + "%s") % fonts)

        val_parallel = pl.ParallelLoader(val_loader, [device]).per_device_loader(device)
        train_parallel = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        
        for train_batch in train_parallel:
            train_targ, train_in, train_att = train_batch
            
            network = network.to(device)
            train_in = train_in.to(device)
            train_att = train_att.to(device)
            train_targ = train_targ.to(device)

            train_preds = network.forward(train_in, train_att)
            train_loss = bce(train_preds, train_targ)/len(train_preds)
            train_f1 = f1_score(train_preds, train_targ.squeeze(dim=1))

            optimizer.zero_grad()
            train_loss.backward()
            xm.optimizer_step(optimizer)
            
            end = time.time()
            batch = batch + 1
            is_print = batch % 10 == 1
            f1 = np.round(train_f1.item(), 3)
            if is_print: print_metric(f1, batch, None, start, end, metric="F1", typ="Train")

        val_loss, val_f1, val_points = 0, 0, 0

        network.eval()
        with torch.no_grad():
            for val_batch in val_parallel:
                targ, val_in, val_att = val_batch

                targ = targ.to(device)
                val_in = val_in.to(device)
                val_att = val_att.to(device)
                network = network.to(device)
                pred = network.forward(val_in, val_att)

                val_points += len(targ)
                val_loss += bce(pred, targ).item()
                val_f1 += f1_score(pred, targ.squeeze(dim=1)).item()*len(pred)
        
        end = time.time()
        val_f1 /= val_points
        val_loss /= val_points
        f1 = xm.mesh_reduce('f1', val_f1, lambda x: sum(x)/len(x))
        loss = xm.mesh_reduce('loss', val_loss, lambda x: sum(x)/len(x))
        print_metric(np.round(f1, 3), None, epoch, start, end, metric="F1", typ="Val")
    
        xm.master_print("")
        val_f1s.append(f1)
        train_f1s.append(train_f1.item())
        val_losses.append(loss)
        train_losses.append(train_loss.item())

    xm.master_print("ENDING TRAINING ...")
    xm.save(network.state_dict(), MODEL_SAVE_PATH)
    del network
    gc.collect()
    
    metric_lists = [val_losses, train_losses, val_f1s, train_f1s]
    metric_names = ['val_loss_', 'train_loss_', 'val_f1_', 'train_f1_']
    
    for i, metric_list in enumerate(metric_lists):
        for j, metric_value in enumerate(metric_list):
            torch.save(metric_value, metric_names[i] + str(j) + '.pt')
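One subtlety the shuffling samplers above share: a DistributedSampler with shuffle=True reshuffles only when set_epoch is called at the start of each epoch; otherwise every epoch replays the same order. A self-contained sketch (the replica count and rank are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(32).float())
sampler = DistributedSampler(dataset, num_replicas=8, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # re-seeds the shuffle for this epoch
    for (batch,) in loader:
        pass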