Example 1
0
    def run_training_epoch(self):

        # get model
        model = self.get_model()

        # Epoch start events
        with self.profiler.profile('on_epoch_start'):
            # callbacks
            self.on_epoch_start()

            # model hooks
            if self.is_function_implemented('on_epoch_start'):
                model.on_epoch_start()

        # track local dataloader so TPU can wrap each epoch
        train_dataloader = self.train_dataloader

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device()
            train_dataloader = xla_pl.ParallelLoader(train_dataloader,
                                                     [device])
            train_dataloader = train_dataloader.per_device_loader(device)

        # bookkeeping
        outputs = []

        # run epoch
        for batch_idx, (batch,
                        is_last_batch) in self.profiler.profile_iterable(
                            enumerate(_with_is_last(train_dataloader)),
                            "get_train_batch"):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx

            model.global_step = self.global_step

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            _outputs = self.run_training_batch(batch, batch_idx)
            batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
            # detach tensors in batch_output before appending to outputs
            outputs.append(_recursive_detach(batch_output))

            # when returning -1 from train_step, we end epoch early
            early_stop_epoch = batch_result == -1

            # update lr
            self.update_learning_rates(interval='step')

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
            can_check_epoch = (self.current_epoch +
                               1) % self.check_val_every_n_epoch == 0
            can_check_val = not self.disable_validation and can_check_epoch
            should_check_val = is_val_check_batch or early_stop_epoch
            should_check_val = should_check_val or (
                is_last_batch and self.val_check_batch == float('inf'))
            should_check_val = can_check_val and should_check_val

            # fast_dev_run always forces val checking after train batch
            if self.fast_dev_run or should_check_val:
                self.run_evaluation(test_mode=self.testing)

            # when logs should be saved
            should_save_log = (
                batch_idx +
                1) % self.log_save_interval == 0 or early_stop_epoch
            if should_save_log or self.fast_dev_run:
                if self.proc_rank == 0 and self.logger is not None:
                    self.logger.save()

            # when metrics should be logged
            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
            if should_log_metrics or self.fast_dev_run:
                # logs user requested information to logger
                self.log_metrics(batch_step_metrics, grad_norm_dic)

            # ---------------
            # CHECKPOINTING, EARLY STOPPING
            # ---------------
            # save checkpoint even when no test or val step are defined
            if self.fast_dev_run or should_check_val:
                self.call_checkpoint_callback()

                if self.enable_early_stop:
                    self.early_stop_callback.check_metrics(
                        self.callback_metrics)

            # progress global step according to grads progress
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                self.global_step += 1
            self.total_batch_idx += 1

            # max steps reached, end training
            if self.max_steps is not None and self.max_steps == self.global_step:
                break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if early_stop_epoch or self.fast_dev_run:
                break

        # process epoch outputs
        if isinstance(
                model,
            (LightningDistributedDataParallel, LightningDataParallel)):
            model = model.module

        if self.is_overriden('training_epoch_end', model=model):
            epoch_output = model.training_epoch_end(outputs)
            _processed_outputs = self.process_output(epoch_output)
            log_epoch_metrics = _processed_outputs[2]
            callback_epoch_metrics = _processed_outputs[3]
            self.log_metrics(log_epoch_metrics, {})
            self.callback_metrics.update(callback_epoch_metrics)

        # in case validation step is missing and you are not running fast-dev to duplicate last batch
        if not self.is_overriden('validation_step') and not (
                self.fast_dev_run or should_check_val):
            self.call_checkpoint_callback()

            if self.enable_early_stop:
                self.early_stop_callback.check_metrics(self.callback_metrics)

        # Epoch end events
        with self.profiler.profile('on_epoch_end'):
            # callbacks
            self.on_epoch_end()
            # model hooks
            if self.is_function_implemented('on_epoch_end'):
                model.on_epoch_end()
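
The loop above iterates over _with_is_last(train_dataloader), a helper (not shown here) that yields each batch together with a flag marking the final one. A minimal sketch of the assumed behavior, which peeks one element ahead:

# Minimal sketch of a _with_is_last-style generator (assumed behavior:
# yield (item, is_last) by looking one element ahead).
def _with_is_last(iterable):
    it = iter(iterable)
    last = next(it)
    for val in it:
        yield last, False  # at least one more batch is coming
        last = val
    yield last, True       # the final batch

# list(_with_is_last([10, 20, 30])) -> [(10, False), (20, False), (30, True)]
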
Example 2
0
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"),
                           map_location=self.args.device))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size(
            )
        else:
            total_train_batch_size = (self.args.train_batch_size *
                                      self.args.gradient_accumulation_steps *
                                      (torch.distributed.get_world_size()
                                       if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_device_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization stepss = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(epochs_trained,
                                int(num_train_epochs),
                                desc="Epoch",
                                disable=not self.is_local_master())
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(
                    train_dataloader,
                    [self.args.device]).per_device_loader(self.args.device)
                epoch_iterator = tqdm(parallel_loader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self.training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0
                            and self.global_step % self.args.logging_steps
                            == 0) or (self.global_step == 1
                                      and self.args.logging_first_step):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss -
                                        logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >=
                            version.parse("1.4") else scheduler.get_lr()[0])
                        logging_loss = tr_loss

                        self.log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(
                            self.args.output_dir,
                            f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(),
                                    os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(),
                                    os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(self.global_step, tr_loss / self.global_step)
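
When resuming, the block above recovers the global step from the checkpoint directory name. An illustrative sketch of that parsing, assuming the usual checkpoint-<step> naming convention:

# Sketch: recovering the global step from a path such as "output/checkpoint-500"
# (same expression as in the resume logic above).
def parse_global_step(model_path):
    return int(model_path.split("-")[-1].split("/")[0])

print(parse_global_step("output/checkpoint-500"))   # 500
print(parse_global_step("output/checkpoint-500/"))  # 500 (trailing slash handled)
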
Example 3
0
def train_loop(folds, fold):

    if CFG.device == 'GPU':
        LOGGER.info(f"========== fold: {fold} training ==========")
    elif CFG.device == 'TPU':
        if CFG.nprocs == 1:
            LOGGER.info(f"========== fold: {fold} training ==========")
        elif CFG.nprocs == 8:
            xm.master_print(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds[CFG.target_cols].values

    train_dataset = TrainDataset(train_folds,
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds,
                                 transform=get_transforms(data='valid'))

    if CFG.device == 'GPU':
        train_loader = DataLoader(train_dataset,
                                  batch_size=CFG.batch_size,
                                  shuffle=True,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=CFG.batch_size * 2,
                                  shuffle=False,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=False)

    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=CFG.batch_size,
                                                   sampler=train_sampler,
                                                   drop_last=True,
                                                   num_workers=CFG.num_workers)

        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=CFG.batch_size *
                                                   2,
                                                   sampler=valid_sampler,
                                                   drop_last=False,
                                                   num_workers=CFG.num_workers)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(optimizer):
        if CFG.scheduler == 'ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=CFG.factor,
                                          patience=CFG.patience,
                                          verbose=True,
                                          eps=CFG.eps)
        elif CFG.scheduler == 'CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CFG.T_max,
                                          eta_min=CFG.min_lr,
                                          last_epoch=-1)
        elif CFG.scheduler == 'CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                    T_0=CFG.T_0,
                                                    T_mult=1,
                                                    eta_min=CFG.min_lr,
                                                    last_epoch=-1)
        return scheduler

    # ====================================================
    # model & optimizer
    # ====================================================
    if CFG.device == 'TPU':
        device = xm.xla_device()
    elif CFG.device == 'GPU':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = CustomResNet200D_WLF(CFG.model_name, pretrained=False)
    model.load_state_dict(
        torch.load(CFG.student, map_location=torch.device('cpu'))['model'])
    model.to(device)

    optimizer = Adam(model.parameters(),
                     lr=CFG.lr,
                     weight_decay=CFG.weight_decay,
                     amsgrad=False)
    scheduler = get_scheduler(optimizer)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.BCEWithLogitsLoss()

    best_score = 0.
    best_loss = np.inf

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_loss = train_fn(train_loader, model, criterion, optimizer,
                                    epoch, scheduler, device)
            elif CFG.nprocs == 8:
                para_train_loader = pl.ParallelLoader(train_loader, [device])
                avg_loss = train_fn(
                    para_train_loader.per_device_loader(device), model,
                    criterion, optimizer, epoch, scheduler, device)
        elif CFG.device == 'GPU':
            avg_loss = train_fn(train_loader, model, criterion, optimizer,
                                epoch, scheduler, device)

        # eval
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_val_loss, preds, _ = valid_fn(valid_loader, model,
                                                  criterion, device)
            elif CFG.nprocs == 8:
                para_valid_loader = pl.ParallelLoader(valid_loader, [device])
                avg_val_loss, preds, valid_labels = valid_fn(
                    para_valid_loader.per_device_loader(device), model,
                    criterion, device)
                preds = idist.all_gather(torch.tensor(preds)).to('cpu').numpy()
                valid_labels = idist.all_gather(
                    torch.tensor(valid_labels)).to('cpu').numpy()
        elif CFG.device == 'GPU':
            avg_val_loss, preds, _ = valid_fn(valid_loader, model, criterion,
                                              device)

        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, CosineAnnealingWarmRestarts):
            scheduler.step()

        # scoring
        score, scores = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        if CFG.device == 'GPU':
            LOGGER.info(
                f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
            )
            LOGGER.info(
                f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
            )
        elif CFG.device == 'TPU':
            if CFG.nprocs == 1:
                LOGGER.info(
                    f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
                )
                LOGGER.info(
                    f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
                )
            elif CFG.nprocs == 8:
                xm.master_print(
                    f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s'
                )
                xm.master_print(
                    f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}'
                )

        if score > best_score:
            best_score = score
            if CFG.device == 'GPU':
                LOGGER.info(
                    f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model'
                )
                torch.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(
                        f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model'
                    )
                elif CFG.nprocs == 8:
                    xm.master_print(
                        f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model'
                    )
                xm.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            if CFG.device == 'GPU':
                LOGGER.info(
                    f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                torch.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(
                        f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model'
                    )
                elif CFG.nprocs == 8:
                    xm.master_print(
                        f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model'
                    )
                xm.save({
                    'model': model.state_dict(),
                    'preds': preds
                }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')

        if CFG.nprocs != 8:
            check_point = torch.load(
                OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
            for c in [f'pred_{c}' for c in CFG.target_cols]:
                valid_folds[c] = np.nan
            valid_folds[[f'pred_{c}'
                         for c in CFG.target_cols]] = check_point['preds']

    return valid_folds
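
The split at the top of train_loop assumes folds is a DataFrame carrying a precomputed 'fold' column (e.g. from a StratifiedKFold split). A tiny self-contained sketch of that selection:

# Tiny sketch of the fold selection used above; the 'fold' column is assumed to
# come from a cross-validation split such as sklearn's StratifiedKFold.
import pandas as pd

folds = pd.DataFrame({'image_id': list('abcdef'), 'fold': [0, 1, 2, 0, 1, 2]})
fold = 0
train_folds = folds[folds['fold'] != fold].reset_index(drop=True)  # folds 1 and 2
valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)  # fold 0 only
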
    def run_training_epoch(self):

        # get model
        model = self.get_model()

        # Epoch start events
        with self.profiler.profile('on_epoch_start'):
            # callbacks
            self.on_epoch_start()

            # model hooks
            if self.is_function_implemented('on_epoch_start'):
                model.on_epoch_start()

        # track local dataloader so TPU can wrap each epoch
        train_dataloader = self.train_dataloader

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device(self.tpu_id)
            train_dataloader = xla_pl.ParallelLoader(train_dataloader,
                                                     [device])
            train_dataloader = train_dataloader.per_device_loader(device)

        # bookkeeping
        outputs = []

        # run epoch
        for batch_idx, (batch,
                        is_last_batch) in self.profiler.profile_iterable(
                            enumerate(_with_is_last(train_dataloader)),
                            "get_train_batch"):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx

            model.global_step = self.global_step

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            _outputs = self.run_training_batch(batch, batch_idx)
            batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs

            # only track outputs when user implements training_epoch_end
            # otherwise we will build up unnecessary memory
            if self.is_overridden('training_epoch_end',
                                  model=self.get_model()):
                outputs.append(batch_output)

            # when returning -1 from train_step, we end epoch early
            early_stop_epoch = batch_result == -1

            # TODO: consolidate all actions that need to take place only after
            # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment)
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                # update lr
                self.update_learning_rates(interval='step')

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
            can_check_epoch = (self.current_epoch +
                               1) % self.check_val_every_n_epoch == 0
            can_check_val = not self.disable_validation and can_check_epoch
            should_check_val = is_val_check_batch or early_stop_epoch
            should_check_val = should_check_val or (
                is_last_batch and self.val_check_batch == float('inf'))
            should_check_val = can_check_val and should_check_val

            # ---------------
            # CHECKPOINTING, EARLY STOPPING
            # ---------------
            # fast_dev_run always forces val checking after train batch
            if self.fast_dev_run or should_check_val:
                self.run_evaluation(test_mode=self.testing)
                self.call_checkpoint_callback()

            # when logs should be saved
            should_save_log = (
                batch_idx +
                1) % self.log_save_interval == 0 or early_stop_epoch
            if should_save_log or self.fast_dev_run:
                if self.proc_rank == 0 and self.logger is not None:
                    self.logger.save()

            # when metrics should be logged
            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
            if should_log_metrics or self.fast_dev_run:
                # logs user requested information to logger
                self.log_metrics(batch_step_metrics, grad_norm_dic)

            # progress global step according to grads progress
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                self.global_step += 1
            self.total_batch_idx += 1

            # max steps reached, end training
            if self.max_steps is not None and self.max_steps == self.global_step:
                break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if early_stop_epoch or self.fast_dev_run:
                break

        if self.use_horovod:
            hvd.join(hvd.local_rank() if self.on_gpu else -1)

        # process epoch outputs
        model = self.get_model()
        if self.is_overridden('training_epoch_end', model=model):
            epoch_output = model.training_epoch_end(outputs)
            _processed_outputs = self.process_output(epoch_output)
            log_epoch_metrics = _processed_outputs[2]
            callback_epoch_metrics = _processed_outputs[3]
            self.log_metrics(log_epoch_metrics, {})
            self.callback_metrics.update(callback_epoch_metrics)
            self.add_progress_bar_metrics(_processed_outputs[1])

        # when no val loop is present or fast-dev-run still need to call checkpoints
        if not self.is_overridden('validation_step') and not (
                self.fast_dev_run or should_check_val):
            self.call_checkpoint_callback()

        # Epoch end events
        with self.profiler.profile('on_epoch_end'):
            # callbacks
            self.on_epoch_end()
            # model hooks
            if self.is_function_implemented('on_epoch_end'):
                model.on_epoch_end()
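
In this variant the learning-rate update and the global-step increment are gated on the same (batch_idx + 1) % accumulate_grad_batches test as the optimizer step. A tiny sketch of that bookkeeping:

# Tiny sketch of the accumulation bookkeeping above: with
# accumulate_grad_batches = 4, global_step advances once every 4 batches.
accumulate_grad_batches = 4
global_step = 0
for batch_idx in range(10):
    if (batch_idx + 1) % accumulate_grad_batches == 0:
        global_step += 1  # triggered after batches 4 and 8
print(global_step)  # 2
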
    def process_dataloader(self, dataloader):
        device = xm.xla_device(self.trainer.tpu_id)
        dataloader = xla_pl.ParallelLoader(dataloader, [device])
        dataloader = dataloader.per_device_loader(device)
        return dataloader
Example 6
0
def main(config_file='config/bert_config.json'):
    """Main method for training.

    Args:
        config_file: in config dir
    """
    global datasets
    # 0. Load config and mkdir
    with open(config_file) as fin:
        config = json.load(fin, object_hook=lambda d: SimpleNamespace(**d))

    get_path(os.path.join(config.model_path, config.experiment_name))
    get_path(config.log_path)
    if config.model_type in ['rnn', 'lr', 'cnn']:  # build vocab for rnn
        build_vocab(file_in=config.all_train_file_path,
                    file_out=os.path.join(config.model_path, 'vocab.txt'))
    # 1. Load data
    data = Data(vocab_file=os.path.join(config.model_path, 'vocab.txt'),
                max_seq_len=config.max_seq_len,
                model_type=config.model_type,
                config=config)

    def load_dataset():
        datasets = data.load_train_and_valid_files(
            train_file=config.train_file_path,
            valid_file=config.valid_file_path)
        return datasets

    if config.serial_load:
        datasets = SERIAL_EXEC.run(load_dataset)
    else:
        datasets = load_dataset()

    train_set, valid_set_train, valid_set_valid = datasets
    if torch.cuda.is_available():
        device = torch.device('cuda')
        # device = torch.device('cpu')
        # torch.distributed.init_process_group(backend="nccl")
        # sampler_train = DistributedSampler(train_set)
        sampler_train = RandomSampler(train_set)
    else:
        device = torch.device('cpu')
        sampler_train = RandomSampler(train_set)
    # TPU
    device = xm.xla_device()
    sampler_train = torch.utils.data.distributed.DistributedSampler(
        train_set,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    data_loader = {
        'train':
        DataLoader(train_set,
                   sampler=sampler_train,
                   batch_size=config.batch_size),
        'valid_train':
        DataLoader(valid_set_train,
                   batch_size=config.batch_size,
                   shuffle=False),
        'valid_valid':
        DataLoader(valid_set_valid,
                   batch_size=config.batch_size,
                   shuffle=False)
    }

    # 2. Build model
    # model = MODEL_MAP[config.model_type](config)
    model = WRAPPED_MODEL
    #load model states.
    # if config.trained_weight:
    #     model.load_state_dict(torch.load(config.trained_weight))
    model.to(device)
    if torch.cuda.is_available():
        model = model
        # model = torch.nn.parallel.DistributedDataParallel(
        #     model, find_unused_parameters=True)

    # 3. Train
    trainer = Trainer(model=model,
                      data_loader=data_loader,
                      device=device,
                      config=config)
    # best_model_state_dict = trainer.train()

    if config.model_type == 'bert':
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate':
            0.01
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate':
            0.0
        }]
        optimizer = AdamW(optimizer_parameters,
                          lr=config.lr,
                          betas=(0.9, 0.999),
                          weight_decay=1e-8,
                          correct_bias=False)
    else:  # rnn
        optimizer = Adam(model.parameters(), lr=config.lr)

    # if config.model_type == 'bert':
    #     scheduler = get_linear_schedule_with_warmup(
    #         optimizer,
    #         num_warmup_steps=config.num_warmup_steps,
    #         num_training_steps=config.num_training_steps)
    # else:  # rnn
    #     scheduler = get_constant_schedule(optimizer)

    criterion = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, batch in enumerate(loader):
            # batch = tuple(t.to(self.device) for t in batch)
            output = model(*batch[:-1])  # the last one is label
            loss = criterion(output, batch[-1])
            loss.backward()
            # xm.optimizer_step(optimizer)
            # optimizer.zero_grad()

            tracker.add(FLAGS.batch_size)
            if (x + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_grad_norm)
                # Gradient accumulation: run several backward passes before
                # optimizer.step(); the gradients accumulate in parameter.grad
                # and the parameters are then updated with the accumulated values.
                xm.optimizer_step(optimizer)
                optimizer.zero_grad()

            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print(
                        '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                        .format(xm.get_ordinal(), x, loss.item(),
                                tracker.rate(), tracker.global_rate(),
                                time.asctime()),
                        flush=True)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        tracker = xm.RateTracker()
        for x, batch in enumerate(loader):
            output = model(*batch[:-1])  # the last one is label
            target = batch[-1]
            # pred = output.max(1, keepdim=True)[1]
            # correct += pred.eq(target.view_as(pred)).sum().item()
            for i in range(len(output)):
                logits = output[i]
                pred = int(torch.argmax(logits, dim=-1))
                if pred == target[i]:
                    correct += 1
            total_samples += len(output)

            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print(
                        '[xla:{}]({}) Acc={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                        .format(xm.get_ordinal(), x,
                                correct * 1.0 / total_samples, tracker.rate(),
                                tracker.global_rate(), time.asctime()),
                        flush=True)

        accuracy = 100.0 * correct / total_samples
        if xm.get_ordinal() == 0:
            print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(),
                                                     accuracy),
                  flush=True)
        return accuracy, data, pred, target

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(FLAGS.num_epoch):
        para_loader = pl.ParallelLoader(data_loader['train'], [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        # para_loader = pl.ParallelLoader(data_loader['valid_train'], [device])
        # accuracy_train, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))

        para_loader = pl.ParallelLoader(data_loader['valid_valid'], [device])
        accuracy_valid, data, pred, target = test_loop_fn(
            para_loader.per_device_loader(device))
        xm.master_print("Finished test epoch {}, valid={:.2f}".format(
            epoch, accuracy_valid))

        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

        # 4. Save model
        # if xm.get_ordinal() == 0:
        #     # if epoch==FLAGS.num_epoch-1:
        #     # WRAPPED_MODEL.to('cpu')
        #     torch.save(WRAPPED_MODEL.state_dict(), os.path.join(
        #         config.model_path, config.experiment_name,
        #         config.model_type + '-' + str(epoch + 1) + '.bin'))
        #     xm.master_print('saved model.')
        # WRAPPED_MODEL.to(device)

    return accuracy_valid
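
train_loop_fn above applies the gradient-accumulation pattern described in its comment: several backward passes, then one clipped optimizer step. A hedged, self-contained sketch of the same pattern on a toy model (names and sizes are illustrative; accum_steps stands in for config.gradient_accumulation_steps):

# Self-contained gradient-accumulation sketch (toy model and data).
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(8)]
accum_steps = 4

for step, (inputs, labels) in enumerate(loader):
    loss = criterion(model(inputs), labels)
    loss.backward()                          # gradients add up in parameter.grad
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()                     # update with the accumulated gradients
        optimizer.zero_grad()
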
Example 7
0
def train_mnist():
    torch.manual_seed(1)

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=60000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(FLAGS.datadir,
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        test_dataset = datasets.MNIST(FLAGS.datadir,
                                      train=False,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    # Scale learning rate to num cores
    lr = FLAGS.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    return accuracy
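
train_mnist is written to run once per TPU core; in the official torch_xla examples such a function is launched with xmp.spawn. A hedged launch sketch (FLAGS is assumed to come from the surrounding script's argument parsing, and nprocs=8 assumes a full 8-core TPU):

# Hedged launch sketch for train_mnist() on all TPU cores.
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, flags):
    global FLAGS
    FLAGS = flags
    train_mnist()

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
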
Example 8
0
def train(train_loader, model, optimizer, scheduler, epoch, args, DEVICE):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model = model.train().to(DEVICE)

    loader = pl.ParallelLoader(train_loader,
                               [DEVICE]).per_device_loader(DEVICE)
    # noise2net = Res2Net(epsilon=0.50, hidden_planes=16, batch_size=args.batch_size).train().to(DEVICE)

    end = time.time()
    for i, (images, target) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        bx = images
        by = target

        print("Zero grad")
        optimizer.zero_grad()

        # with torch.no_grad():
        #     if random.random() < 0.5:
        #         batch_size = bx.shape[0]
        #         noise2net.reload_parameters()
        #         noise2net.set_epsilon(random.uniform(args.noisenet_max_eps / 2.0, args.noisenet_max_eps))
        #         bx = bx.reshape((1, batch_size * 3, 224, 224))
        #         bx = noise2net(bx)
        #         bx = bx.reshape((batch_size, 3, 224, 224))

        print("Forward")
        logits = model(bx)

        print("Cross Entropy")
        loss = F.cross_entropy(logits, by)

        # measure accuracy and record loss
        output, target = logits, by
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        print("Backward")
        loss.backward()

        print("Step")
        xm.optimizer_step(optimizer)

        print("Scheduler step")
        scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and xm.is_master_ordinal():
            progress.display(i)
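
The loop relies on an accuracy(output, target, topk=(1, 5)) helper that is not shown; below is a sketch of the usual ImageNet-example implementation it is assumed to match, plus a small usage check:

# Sketch of the assumed top-k accuracy helper (as in the standard PyTorch
# ImageNet example), returning percentages per requested k.
import torch

def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)  # indices of the top-k classes
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

logits = torch.tensor([[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]])
labels = torch.tensor([1, 2])
top1, = accuracy(logits, labels, topk=(1,))  # tensor([50.])
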
Example 9
0
def main(config_file='config/bert_config.json'):
    """Main method for training.

    Args:
        config_file: in config dir
    """
    # 0. Load config and mkdir
    with open(config_file) as fin:
        config = json.load(fin, object_hook=lambda d: SimpleNamespace(**d))
    get_path(os.path.join(config.model_path, config.experiment_name))
    get_path(config.log_path)
    get_path(config.prediction_path)
    # get_path(config.checkpoint_path)
    if config.model_type == 'rnn':  # build vocab for rnn
        build_vocab(file_in=config.all_train_file_path,
                    file_out=os.path.join(config.model_path, 'vocab.txt'))
    # 1. Load data
    data = Data(vocab_file=os.path.join(config.model_path, 'vocab.txt'),
                max_seq_len=config.max_seq_len,
                model_type=config.model_type,
                config=config)

    def load_dataset():
        datasets = data.load_train_and_valid_files(
            train_file=config.train_file_path,
            valid_file=config.valid_file_path)
        return datasets

    if config.serial_load:
        datasets = SERIAL_EXEC.run(load_dataset)
    else:
        datasets = load_dataset()

    train_set, valid_set_train, valid_set_valid, train_exam, valid_exam, train_feat, valid_feat = datasets
    if torch.cuda.is_available():
        device = torch.device('cuda')
        # device = torch.device('cpu')
        # torch.distributed.init_process_group(backend="nccl")
        # sampler_train = DistributedSampler(train_set)
        sampler_train = RandomSampler(train_set)
    else:
        device = torch.device('cpu')
        sampler_train = RandomSampler(train_set)

    # TPU
    device = xm.xla_device()
    sampler_train = torch.utils.data.distributed.DistributedSampler(
        train_set,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    data_loader = {
        'train':
        DataLoader(train_set,
                   sampler=sampler_train,
                   batch_size=config.batch_size),
        'valid_train':
        DataLoader(valid_set_train,
                   batch_size=config.batch_size,
                   shuffle=False),
        'valid_valid':
        DataLoader(valid_set_valid,
                   batch_size=config.batch_size,
                   shuffle=False),
        'train_exam':
        train_exam,
        'valid_exam':
        valid_exam,
        'train_feat':
        train_feat,
        'valid_feat':
        valid_feat
    }
    # 2. Build model
    # TPU

    device = xm.xla_device()
    model = WRAPPED_MODEL
    model.to(device)

    if config.model_type == 'bert':
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate':
            0.01
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate':
            0.0
        }]
        optimizer = AdamW(optimizer_parameters,
                          lr=config.lr,
                          betas=(0.9, 0.999),
                          weight_decay=1e-8,
                          correct_bias=False)
    else:  # rnn
        optimizer = Adam(model.parameters(), lr=config.lr)

    criterion = nn.CrossEntropyLoss(reduction='mean',
                                    ignore_index=IGNORE_INDEX)  # cross-entropy loss
    binary_criterion = nn.BCEWithLogitsLoss(reduction='mean')  # binary loss
    sp_loss_fct = nn.BCEWithLogitsLoss(reduction='none')  # for sp; compute the mean manually

    #load model states.
    # if config.trained_weight:
    #     model.load_state_dict(torch.load(config.trained_weight))
    # model.to(device)
    # if torch.cuda.is_available():
    #     model = model
    # model = torch.nn.parallel.DistributedDataParallel(
    #     model, find_unused_parameters=True)
    # 3. Train
    trainer = Trainer(model=model,
                      data_loader=data_loader,
                      device=device,
                      config=config)

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, batch in enumerate(loader):
            batch = tuple(t.to(device) for t in batch)
            # loss = self.criterion(logits, batch[-1])
            start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model(
                *batch)
            loss1 = criterion(start_logits, batch[6]) + criterion(
                end_logits, batch[7])  # y1, y2
            loss2 = config.type_lambda * criterion(type_logits,
                                                   batch[8])  # q_type
            # sent_num_in_batch = batch[9].sum()  # is_support
            # sent_num_in_batch = 1.0 + sent_num_in_batch # to avoid devide by zero
            # loss3 = self.sp_loss_fct(sp_logits.view(-1), batch[10].float().view(-1)).sum() * self.config.sp_lambda / sent_num_in_batch
            loss = loss1 + loss2
            loss.backward()
            del batch  # try to save CPU memory
            tracker.add(FLAGS.batch_size)
            if (x + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_grad_norm)
                # Gradient accumulation: run several backward passes before
                # optimizer.step(); the gradients accumulate in parameter.grad
                # and the parameters are then updated with the accumulated values.
                xm.optimizer_step(optimizer)
                optimizer.zero_grad()

            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print(
                        '[xla:{}]({}) Loss1={:.5f} Loss2={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                        .format(xm.get_ordinal(), x, loss1.item(),
                                loss2.item(), tracker.rate(),
                                tracker.global_rate(), time.asctime()),
                        flush=True)

    # def test_loop_fn(loader):
    #     total_samples = 0
    #     correct = 0
    #     model.eval()
    #     data, pred, target = None, None, None
    #     tracker = xm.RateTracker()
    #     for x, batch in enumerate(loader):
    #         start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model(*batch)
    #         loss1 = self.criterion(start_logits, batch[6]) + self.criterion(end_logits, batch[7])#y1,y2
    #         loss2 = self.config.type_lambda * self.criterion(type_logits, batch[8])#q_type
    #         sent_num_in_batch = batch[9].sum()  # start_mapping
    #         sp_value = self.sp_loss_fct(sp_logits.view(-1), batch[10].float().view(-1)).sum()
    #         if sent_num_in_batch != 0:
    #             loss3 = self.config.sp_lambda * sp_value / sent_num_in_batch
    #         else:
    #             loss3 = self.config.sp_lambda * sp_value * 1e30
    #
    #         loss = loss1 + loss2 + loss3
    #         loss_list = [loss, loss1, loss2, loss3]
    #
    #         for i, l in enumerate(loss_list):
    #             if not isinstance(l, int):
    #                 total_test_loss[i] += l.item()
    #
    #         batchsize = batch[0].size(0)
    #         # ids
    #         answer_dict_ = convert_to_tokens(exam, feats, batch[5], start_position.data.cpu().numpy().tolist(),
    #                                          end_position.data.cpu().numpy().tolist(),
    #                                          np.argmax(type_logits.data.cpu().numpy(), 1))
    #         answer_dict.update(answer_dict_)
    #
    #         predict_support_np = torch.sigmoid(sp_logits).data.cpu().numpy()
    #         for i in range(predict_support_np.shape[0]):
    #             cur_sp_pred = []
    #             cur_id = batch[5][i].item()
    #
    #             cur_sp_logit_pred = []  # for sp logit output
    #             for j in range(predict_support_np.shape[1]):
    #                 if j >= len(exam[cur_id].sent_names):
    #                     break
    #                 if need_sp_logit_file:
    #                     temp_title, temp_id = exam[cur_id].sent_names[j]
    #                     cur_sp_logit_pred.append((temp_title, temp_id, predict_support_np[i, j]))
    #                 if predict_support_np[i, j] > self.config.sp_threshold:
    #                     cur_sp_pred.append(exam[cur_id].sent_names[j])
    #             sp_dict.update({cur_id: cur_sp_pred})
    #
    #     new_answer_dict = {}
    #     for key, value in answer_dict.items():
    #         new_answer_dict[key] = value.replace(" ", "")
    #     prediction = {'answer': new_answer_dict, 'sp': sp_dict}
    #     with open(prediction_file, 'w', encoding='utf8') as f:
    #         json.dump(prediction, f, indent=4, ensure_ascii=False)
    #
    #     for i, l in enumerate(total_test_loss):
    #         print("Test Loss{}: {}".format(i, l / len(dataloader)))
    #
    #     if xm.get_ordinal() == 0:
    #             if x % FLAGS.log_steps == 0:
    #                 print('[xla:{}]({}) Acc={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
    #                     xm.get_ordinal(), x, correct*1.0/total_samples, tracker.rate(),
    #                     tracker.global_rate(), time.asctime()), flush=True)
    #
    #     accuracy = 100.0 * correct / total_samples
    #     if xm.get_ordinal() == 0:
    #         print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy), flush=True)
    #     return accuracy, data, pred, target

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(FLAGS.num_epoch):
        para_loader = pl.ParallelLoader(data_loader['train'], [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        # para_loader = pl.ParallelLoader(data_loader['valid_train'], [device])
        # accuracy_train, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))

        # para_loader = pl.ParallelLoader(data_loader['valid_valid'], [device])
        # accuracy_valid, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))
        # xm.master_print("Finished test epoch {}, valid={:.2f}".format(epoch, accuracy_valid))

    # # 4. Save model
    # torch.save(best_model_state_dict,
    #            os.path.join(config.model_path, 'model.bin'))
    if xm.get_ordinal() == 0:
        results = trainer.valid()
        xm.master_print("Validation results: {}".format(results))
Esempio n. 10
0
    def __init__(self, mode, opts):
        print('Initializing model and data source...', end='', file=stderr)
        stderr.flush()
        self.learning_rates = dict(
            zip(opts.learning_rate_steps, opts.learning_rate_rates))
        self.opts = opts

        if mode == 'new':
            torch.manual_seed(opts.random_seed)

            # Initialize data
            dataset = data.Slice(opts)
            dataset.load_data(opts.dat_file)
            opts.training = True
            if opts.global_model == 'autoencoder':
                model = ae.AutoEncoder(opts, dataset)
            elif opts.global_model == 'mfcc_inverter':
                model = mi.MfccInverter(opts, dataset)

            model.post_init(dataset)
            dataset.post_init(model)
            optim = torch.optim.Adam(params=model.parameters(),
                                     lr=self.learning_rates[0])
            self.state = checkpoint.State(0, model, dataset, optim)
            self.start_step = self.state.step

        else:
            self.state = checkpoint.State()
            self.state.load(opts.ckpt_file, opts.dat_file)
            self.start_step = self.state.step
            # print('Restored model, data, and optim from {}'.format(opts.ckpt_file), file=stderr)
            #print('Data state: {}'.format(state.data), file=stderr)
            #print('Model state: {}'.format(state.model.checksum()))
            #print('Optim state: {}'.format(state.optim_checksum()))
            stderr.flush()

        if self.state.model.bn_type == 'vae':
            self.anneal_schedule = dict(
                zip(opts.bn_anneal_weight_steps, opts.bn_anneal_weight_vals))

        self.ckpt_path = util.CheckpointPath(self.opts.ckpt_template)
        self.quant = None
        self.target = None
        self.softmax = torch.nn.Softmax(1)  # input to this is (B, Q, N)

        if self.opts.hwtype == 'GPU':
            self.device = torch.device('cuda')
            self.data_loader = self.state.data_loader
            self.data_loader.set_target_device(self.device)
            self.optim_step_fn = (lambda: self.state.optim.step(self.loss_fn))
            self.data_iter = GPULoaderIter(iter(self.data_loader))
        else:
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.parallel_loader as pl
            self.device = xm.xla_device()
            self.data_loader = pl.ParallelLoader(self.state.data_loader,
                                                 [self.device])
            self.data_iter = TPULoaderIter(self.data_loader, self.device)
            self.optim_step_fn = (lambda: xm.optimizer_step(
                self.state.optim, optimizer_args={'closure': self.loss_fn}))

        self.state.init_torch_generator()
        print('Done.', file=stderr)
        stderr.flush()
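The constructor above wires a hardware-specific `optim_step_fn`. A condensed sketch of that choice in isolation (assuming torch_xla is available; the `loss_fn` closure and `hwtype` flag mirror the fields used above and are otherwise illustrative):

import torch_xla.core.xla_model as xm

def make_optim_step_fn(optim, loss_fn, hwtype):
    if hwtype == 'GPU':
        # Plain PyTorch optimizers accept an optional closure that re-evaluates the loss.
        return lambda: optim.step(loss_fn)
    # On TPU the closure is forwarded to the wrapped optimizer via optimizer_args.
    return lambda: xm.optimizer_step(optim, optimizer_args={'closure': loss_fn})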
Esempio n. 11
0
    def train(self, model_path=None, dev_objective=None):
        """
        Main training entry point.

        The training logic is directly borrowed from transformers.Trainer (version 3.0.2).
        Add early stopping.
        """
        self.best_dir = None
        self.objective = -float("inf")
        self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective

        # Data loading.
        train_dataloader = self.get_train_dataloader()
        num_update_steps_per_epoch = len(
            train_dataloader) // self.args.gradient_accumulation_steps
        if num_update_steps_per_epoch == 0:
            num_update_steps_per_epoch = 1
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        self.create_optimizer_and_scheduler(num_training_steps=t_total)
        optimizer = self.optimizer
        scheduler = self.lr_scheduler

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"),
                           map_location=self.args.device))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model

        if self.args.fp16 and _use_apex:
            if not transformers.is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        # Train
        if transformers.is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size(
            )
        else:
            total_train_batch_size = (self.args.train_batch_size *
                                      self.args.gradient_accumulation_steps *
                                      (torch.distributed.get_world_size()
                                       if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_device_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = torch.tensor(0.0).to(self.args.device)
        logging_loss_scalar = 0.0
        model.zero_grad()
        train_iterator = trange(epochs_trained,
                                int(num_train_epochs),
                                desc="Epoch",
                                disable=not self.is_local_master())
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if transformers.is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(
                    train_dataloader,
                    [self.args.device]).per_device_loader(self.args.device)
                epoch_iterator = tqdm(parallel_loader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader,
                                      desc="Iteration",
                                      disable=True)

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self.training_step(model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(optimizer)
                        norm = torch.nn.utils.clip_grad_norm_(
                            model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16:
                        norm = torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        norm = torch.nn.utils.clip_grad_norm_(
                            model.parameters(), self.args.max_grad_norm)

                    if transformers.is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(optimizer)
                        self.scaler.update()
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0
                            and self.global_step % self.args.logging_steps
                            == 0) or (self.global_step == 1
                                      and self.args.logging_first_step):
                        logs = {}
                        tr_loss_scalar = tr_loss.item()
                        logs["loss"] = (tr_loss_scalar - logging_loss_scalar
                                        ) / self.args.logging_steps
                        logs["norm"] = norm.item()
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >=
                            version.parse("1.4") else scheduler.get_lr()[0])
                        logging_loss_scalar = tr_loss_scalar

                        self.log(logs)

                    # ----------------------------------------------------------------------
                    # BEGIN CHANGES.
                    # ----------------------------------------------------------------------

                    metrics = None
                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        output = self.evaluate()
                        metrics = output.metrics
                        objective = self.dev_objective(metrics)
                        if objective > self.objective:
                            logger.info(
                                "Best dev result: {}".format(objective))
                            self.objective = objective
                            self.save_model(self.args.output_dir)

                    # ----------------------------------------------------------------------
                    # END CHANGES.
                    # ----------------------------------------------------------------------

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(self.global_step,
                           tr_loss / self.global_step), self.objective
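The accumulation branch above interleaves unscaling, clipping, the device-specific optimizer step, the scheduler and zero_grad. A compact sketch of that ordering (the boolean flags and the `scaler` argument are illustrative stand-ins for the trainer's state; torch_xla is assumed to be installed):

import torch
import torch_xla.core.xla_model as xm

def apply_update(model, optimizer, scheduler, max_grad_norm,
                 on_tpu=False, use_native_amp=False, scaler=None):
    if use_native_amp:
        scaler.unscale_(optimizer)          # gradients must be unscaled before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if on_tpu:
        xm.optimizer_step(optimizer)        # all-reduces gradients across TPU cores
    elif use_native_amp:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    scheduler.step()
    model.zero_grad()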
Esempio n. 12
0
def train(rank, args):
    print('enter train @ %s'%(rank), flush=True)
    args.rank = rank
    args.split = ''
    torch.manual_seed(42)
    save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size() if not args.vocab_size else args.vocab_size
    
    train_dataset = get_dataset(args)
    
    batched_already = hasattr(train_dataset, '__getbatch__')

    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates

    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)
    
    train_sampler = None
    if args.gpus:
        dist.init_process_group(
            'nccl', 
            rank=rank, 
            world_size=args.world_size
        )
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=args.gpus,
                rank=rank,
                shuffle=args.shuffle)

    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=rank,
                shuffle=args.shuffle)


    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size if not batched_already else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)
        

    eval_loaders = []
    if args.eval_dir:
        for split in args.splits.split(','):
            split = split.strip()
            args.split = split
            eval_dataset = get_eval_dataset(args)

            eval_sampler = None
            if args.gpus:
                if args.gpus > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset,
                        num_replicas=args.gpus,
                        rank=rank,
                        shuffle=False)

            else:
                rank = xm.get_ordinal()
                if xm.xrt_world_size() > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset,
                        num_replicas=xm.xrt_world_size(),
                        rank=rank,
                        shuffle=False)

            eval_loader = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=args.batch_size if not batched_already else None,
                sampler=eval_sampler,
                pin_memory=True,
                shuffle=False,
                num_workers=args.num_workers)
            eval_loaders.append(eval_loader)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)


        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args, tokenizer)

        model.cuda(rank)

        device = torch.device('cuda:'+str(rank))

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay,
                                 special_layer_wise_lr=args.special_layer_wise_lr,
                                 log = rank == 0,
                                 ),  

                                 # use this function to set extra optimizer arguments, 
                                 # see model_get_parameters
            betas=(0.9, 0.999), 
            eps=1e-6,
            lr=args.lr, 
            weight_decay=args.weight_decay
        )




        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader

    else:
        assert tpu_enabled
        device = xm.xla_device()


        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        
        model = get_model(args, tokenizer)


        ##########################
        ##
        ##  For shared parameters, TPU requires modules to be tied after .to(device)
        ##  So we record the shared parameters before moving the model
        ##
        ##########################

        shared_parameters = {e[0]: e[1:] for e in _catalog_shared_params(model)}

        model.to(device)
        
        do_share_parameters_again(model, shared_parameters, log = rank == 0)



        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay
                                 ),  

                                 # use this function to set extra optimizer arguments, 
                                 # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay
        )


        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)
                
        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

        # tracker = xm.RateTracker()
        
        
        
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
                
        if args.gpus:
            states = {"module.%s"%k : v for k, v in states.items()}
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            if rank == 0:
                traceback.print_exc()
            model.load_state_dict(states, strict=False)
            
        
    if rank == 0:
        if not os.path.exists(os.path.dirname(save_fn)):
            try:
                os.makedirs(os.path.dirname(save_fn))
            except OSError as exc: # Guard against race condition
                if exc.errno != errno.EEXIST:
                    raise
        if args.gpus:
            torch.save(model.state_dict(), save_fn )
        else:
            xm.save(model.state_dict(), save_fn )
        
    model.train()

    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ##  Init LR Scheduler
    ##
    ##########################
    
    if not batched_already:
        args.total_num_updates = args.total_num_updates // args.batch_size
        args.warmup_updates = args.warmup_updates // args.batch_size

    args.total_num_updates = args.total_num_updates // args.world_size
    args.warmup_updates = args.warmup_updates // args.world_size

    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=args.warmup_updates, 
        num_training_steps=args.total_num_updates, 
    )

    step_i = 0

    err = None
    tb = None
    #tb = SummaryWriter()
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates, file=sys.stdout)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
                
            n_samples = len(batches)
                
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break

                report_step = step_i % args.log_interval == 0

                while True: # the loop only for apex Gradient Overflow
                    optimizer.zero_grad()

                    total_loss, log = get_loss(
                        model, 
                        sample, 
                        args=args, 
                        device=device, 
                        gpus=args.gpus, 
                        report=report_step
                    )

                    if args.gpus:
                        default_optimizer_step = optimizer.step

                        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()

                        # If Amp detects an overflow, it patches optimizer.step.  In other words, if optimizer.step
                        # was left unpatched, there was no overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break

                        optimizer.step() # If an overflow was detected, "optimizer.step" is the patched call, which does 
                                         # nothing but restore optimizer.step to default_optimizer_step.
                        if rank == 0:
                            print("Overflowed, reducing loss scale and replaying batch.", flush=True)
                        
                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()

                        break



                scheduler.step()

                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss

                    # tb.add_scalar("Loss", total_loss, step_i)

                    for k, v in log.items():
                        try:
                            dist.all_reduce(v, op=dist.reduce_op.SUM)
                            log[k] = float(v)
                        except Exception as e:
                            print(v, e)
                            pass
                        
                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(log, log_formatter, tb, step_i))
                    else:
                        xm.add_step_closure(_train_update, args=(log, log_formatter, tb, step_i))

                    if args.report_metrics:
                        xm.master_print(met.metrics_report())

                
                if rank == 0:
                    pbar.update(1)

        if rank == 0:
            pbar.close()
        if eval_loaders:
            model.half()
            model.eval()
            model.cuda()
            for k, v in model.named_parameters():
                v.requires_grad = False

                
            for split, eval_loader in zip(args.splits.split(','), eval_loaders):
                batches = eval_loader
                if rank == 0:
                    eval_length = len(batches)
                    if not batched_already:
                        eval_length = eval_length // args.batch_size

                    eval_length = eval_length // args.world_size

                    pbar = tqdm(total=eval_length, file=sys.stdout)
                
                if not args.gpus:
                    batches = pl.ParallelLoader(eval_loader, [device]).per_device_loader(device)
                with torch.no_grad():
                    record = OrderedDict()

                    for sample in batches:
                        evaluate(
                            model, 
                            sample, 
                            args=args, 
                            device=device, 
                            record=record,
                            gpus=args.gpus, 
                            report=False
                        )
                        if rank == 0:
                            pbar.update(1)

                    for k, v in record.items():
                        try:
                            def handle_reduce(v):
                                if len(v.shape) == 0:
                                    dist.all_reduce(v, op=dist.reduce_op.SUM)
                                else:
                                    L = [torch.ones_like(v) for _ in range(dist.get_world_size())]
                                    dist.all_gather(L, v)
                                    v = torch.cat(L, dim=0)
                                return v
                            if isinstance(v, list):
                                v = [handle_reduce(e) for e in v]
                            else:
                                v = handle_reduce(v)
                            record[k] = float(v)
                        except Exception as e:
                            pass

                    post_evaluate(record, args=args)

                import json

                if rank == 0:
                    print('',flush=True)
                    print('Test result for %s'%split, flush=True)
                    print(json.dumps(record, indent=2),flush=True)
                    print('',flush=True)


    except Exception as _err:
        err = _err
    finally:
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0:
            print("Saving to %s"%save_fn)
            if args.gpus:
                torch.save(model.state_dict(), save_fn )
                if err:
                    raise err
            else:
                xm.save(model.state_dict(), save_fn )
                if err:
                    raise err
            print("Saved to %s"%save_fn)
Esempio n. 13
0
    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
        device = xm.xla_device()
        dataloader = xla_pl.ParallelLoader(dataloader, [device])
        dataloader = dataloader.per_device_loader(device)
        return dataloader
Esempio n. 14
0
    def get_tpu_dataloader(self, data_loader):
        return pl.ParallelLoader(data_loader,
                                 [self.device]).per_device_loader(self.device)
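Esempio 13 and 14 are the same one-liner behind different hooks: wrap the ordinary DataLoader in a ParallelLoader and iterate its per-device view. A hypothetical usage inside an epoch loop (the loader and the `step_fn` callback are assumed to exist):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def run_epoch(train_loader, step_fn):
    device = xm.xla_device()
    # The per-device view is consumed once, which is why the examples above rebuild it every epoch.
    device_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    for batch in device_loader:
        step_fn(batch)  # batches arrive already placed on the XLA device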
Esempio n. 15
0
    def _prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.
        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(
                dataloader,
                [self.args.device]).per_device_loader(self.args.device)

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(
                inputs.get(k) is not None
                for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                if preds is None:
                    preds = logits.detach()
                else:
                    preds = torch.cat((preds, logits.detach()), dim=0)
                if inputs.get("labels") is not None:
                    if label_ids is None:
                        label_ids = inputs["labels"].detach()
                    else:
                        label_ids = torch.cat(
                            (label_ids, inputs["labels"].detach()), dim=0)

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(
                    preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(
                    label_ids,
                    num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids,
                                           torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)
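A compact sketch of the TPU-side pattern in the loop above: wrap the loader per device, accumulate detached logits, then merge the shards held by each core with `mesh_reduce`. It assumes the batches are keyword-argument dictionaries without labels, as in the prediction-only path of the example:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def tpu_predict(model, dataloader, device):
    model.eval()
    preds = None
    loader = pl.ParallelLoader(dataloader, [device]).per_device_loader(device)
    with torch.no_grad():
        for inputs in loader:
            logits = model(**inputs)[0]
            preds = logits.detach() if preds is None else torch.cat(
                (preds, logits.detach()), dim=0)
    # Every core contributes its shard; torch.cat merges them on all ordinals.
    return xm.mesh_reduce("eval_preds", preds, torch.cat)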
Esempio n. 16
0
    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.
        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
            self.args.max_steps = t_total

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])

                epochs_trained = self.global_step // num_update_steps_per_epoch
                steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss_sum = 0.0
        loss_sum = defaultdict(float)
        best = {self.best_metric: None}
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                model.train()
                inputs = self._prepare_inputs(inputs)

                inputs["output_attentions"] = self.length_drop_args.length_config is not None

                layer_config = sample_layer_configuration(
                    model.config.num_hidden_layers,
                    layer_dropout_prob=self.length_drop_args.layer_dropout_prob,
                    layer_dropout=0,
                )
                inputs["layer_config"] = layer_config

                inputs["length_config"] = self.length_drop_args.length_config

                outputs = model(**inputs)
                # Save past state if it exists
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index]
                task_loss = self.div_loss(outputs[0])
                if self.length_drop_args.length_adaptive:
                    loss_sum["full"] += task_loss.item()
                loss = task_loss
                if self.length_drop_args.length_adaptive:
                    loss = loss / (self.length_drop_args.num_sandwich + 2)

                tr_loss_sum += loss.item()
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # inplace distillation
                if self.length_drop_args.length_adaptive:
                    logits = outputs[1].detach()

                    for i in range(self.length_drop_args.num_sandwich + 1):
                        inputs["output_attentions"] = True

                        layer_config = sample_layer_configuration(
                            model.config.num_hidden_layers,
                            layer_dropout_prob=self.length_drop_args.layer_dropout_prob,
                            layer_dropout=(self.length_drop_args.layer_dropout_bound if i == 0 else None),
                            layer_dropout_bound=self.length_drop_args.layer_dropout_bound,
                        )
                        inputs["layer_config"] = layer_config

                        length_config = sample_length_configuration(
                            self.args.max_seq_length,
                            model.config.num_hidden_layers,
                            layer_config,
                            length_drop_ratio=(self.length_drop_args.length_drop_ratio_bound if i == 0 else None),
                            length_drop_ratio_bound=self.length_drop_args.length_drop_ratio_bound,
                        )
                        inputs["length_config"] = length_config

                        outputs_sub = model(**inputs)
                        task_loss_sub = self.div_loss(outputs_sub[0])
                        if i == 0:
                            loss_sum["smallest"] += task_loss_sub.item()
                            loss_sum["sub"] += 0
                        else:
                            loss_sum["sub"] += task_loss_sub.item() / self.length_drop_args.num_sandwich

                        logits_sub = outputs_sub[1]
                        loss_fct = KLDivLoss(reduction="batchmean")
                        kl_loss = loss_fct(F.log_softmax(logits, -1), F.softmax(logits_sub, -1))
                        loss = self.div_loss(kl_loss)
                        loss_sum["kl"] += loss.item() / (self.length_drop_args.num_sandwich + 1)
                        loss = loss / (self.length_drop_args.num_sandwich + 2)

                        tr_loss_sum += loss.item()
                        if self.args.fp16 and _use_native_amp:
                            self.scaler.scale(loss).backward()
                        elif self.args.fp16 and _use_apex:
                            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    (step + 1) == len(epoch_iterator) <= self.args.gradient_accumulation_steps
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        # backward compatibility for pytorch schedulers
                        lr = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        loss = tr_loss_sum / self.args.logging_steps
                        tr_loss_sum = 0.0
                        logs = {"lr": lr, "loss": loss}
                        log_str = f"[{self.global_step:5d}] lr {lr:g} | loss {loss:2.3f}"

                        for key, value in loss_sum.items():
                            value /= self.args.logging_steps
                            loss_sum[key] = 0.0
                            logs[f"{key}_loss"] = value
                            log_str += f" | {key}_loss {value:2.3f}"

                        self.log(logs, "train")
                        logger.info(log_str)

                    '''
                    if (
                        self.args.evaluation_strategy == EvaluationStrategy.STEPS
                        and self.global_step % self.args.eval_steps == 0
                    ):
                        results = self.evaluate()
                        self._report_to_hp_search(trial, epoch, results)
                    '''

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert (
                                model.module is self.model
                            ), f"Module {model.module} should be a reference to self.model"
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"

                        if self.args.evaluate_during_training:
                            results = self.evaluate()
                            results = {k[5:]: v for k, v in results.items() if k.startswith("eval_")}
                            self.log(results, "dev")
                            msg = " | ".join([f"{k} {v:.3f}" for k, v in results.items()])
                            logger.info(f"  [{self.global_step:5d}] {msg}")

                        # Save model checkpoint
                        if self.args.save_only_best:
                            output_dirs = []
                        else:
                            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
                            if self.hp_search_backend is not None and trial is not None:
                                run_id = (
                                    trial.number
                                    if self.hp_search_backend == HPSearchBackend.OPTUNA
                                    else tune.get_trial_id()
                                )
                                checkpoint_folder += f"-run-{run_id}"
                            output_dirs = [os.path.join(self.args.output_dir, checkpoint_folder)]
                            
                        if self.args.evaluate_during_training:
                            if best[self.best_metric] is None or results[self.best_metric] > best[self.best_metric]:
                                logger.info("Congratulations, best model so far!")
                                output_dirs.append(os.path.join(self.args.output_dir, "checkpoint-best"))
                                best = results

                        for output_dir in output_dirs:
                            self.save_model(output_dir)

                            if self.is_world_master() and self.tokenizer is not None:
                                self.tokenizer.save_pretrained(output_dir)

                            if self.is_world_process_zero():
                                self._rotate_checkpoints(use_mtime=True)

                            '''
                            if is_torch_tpu_available():
                                xm.rendezvous("saving_optimizer_states")
                                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            elif self.is_world_process_zero():
                                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            '''

                epoch_pbar.update(1)
                if 0 < self.args.max_steps <= self.global_step:
                    break
            epoch_pbar.close()
            train_pbar.update(1)

            '''
            if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
                results = self.evaluate()
                self._report_to_hp_search(trial, epoch, results)
            '''

            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if 0 < self.args.max_steps <= self.global_step:
                break

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return self.global_step, best
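A minimal sketch of the in-place distillation loss used above: the sub-network's logits are compared against the full model's logits with a batch-mean KL term. In the example the full-model logits are detached before this point, so gradient flows only through the sub-network:

import torch.nn.functional as F
from torch.nn import KLDivLoss

def inplace_distillation_loss(full_logits, sub_logits):
    # KLDivLoss expects log-probabilities as input and probabilities as target,
    # matching the softmax over the last dimension used in the example.
    loss_fct = KLDivLoss(reduction="batchmean")
    return loss_fct(F.log_softmax(full_logits, dim=-1),
                    F.softmax(sub_logits, dim=-1))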
Esempio n. 17
0
    def train(self, data_loader):
        losses = AverageMeter()
        self.model.train()
        print_idx = int(len(data_loader) * self.tpu_print / 100)
        if self.accumulation_steps > 1:
            self.optimizer.zero_grad()
        if self.use_tpu:
            para_loader = pl.ParallelLoader(data_loader, [self.device])
            tk0 = para_loader.per_device_loader(self.device)
        else:
            tk0 = tqdm(data_loader, total=len(data_loader))

        for b_idx, data in enumerate(tk0):
            if self.accumulation_steps == 1 and b_idx == 0:
                self.optimizer.zero_grad()

            if self.model_fn is None:
                for key, value in data.items():
                    data[key] = value.to(self.device)
                _, loss = self.model(**data)
            else:
                loss = self.model_fn(data, self.device, self.model)

            if not self.use_tpu:
                with torch.set_grad_enabled(True):
                    if self.use_mean_loss:
                        loss = loss.mean()
                    if self.fp16:
                        with amp.scale_loss(loss,
                                            self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    if (b_idx + 1) % self.accumulation_steps == 0:
                        self.optimizer.step()
                        if self.scheduler is not None:
                            self.scheduler.step()
                        if b_idx > 0:
                            self.optimizer.zero_grad()
            else:
                loss.backward()
                xm.optimizer_step(self.optimizer)
                if self.scheduler is not None:
                    self.scheduler.step()
                if b_idx > 0:
                    self.optimizer.zero_grad()
            if self.use_tpu:
                reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
                losses.update(reduced_loss.item(), data_loader.batch_size)
            else:
                losses.update(loss.item(), data_loader.batch_size)

            if not self.use_tpu:
                tk0.set_postfix(loss=losses.avg)
            else:
                if b_idx % print_idx == 0 or b_idx == len(data_loader):
                    xm.master_print(
                        f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                    )
        if not self.use_tpu:
            tk0.close()
        return losses.avg
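The TPU branch above averages the loss across cores via `xm.mesh_reduce`. The `reduce_fn` it references is not shown in the example; a plausible implementation simply averages the per-ordinal values:

import torch_xla.core.xla_model as xm

def reduce_fn(values):
    # mesh_reduce hands reduce_fn a list with one entry per TPU ordinal.
    return sum(values) / len(values)

def average_loss_across_cores(loss):
    return xm.mesh_reduce("loss_reduce", loss, reduce_fn)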
Esempio n. 18
0
    def train_model(index):
        device = xm.xla_device()
        torch.manual_seed(0)

        if not os.path.exists(tokenized_data_path):
            os.mkdir(tokenized_data_path)
        if not args.pretrained_model:
            model = transformers.modeling_gpt2.GPT2LMHeadModel(
                config=model_config)
        else:
            model = transformers.modeling_gpt2.GPT2LMHeadModel(
                config=model_config)
            model.load_state_dict(torch.load(output_dir + 'final_model'))
        model.train()
        model.to(device)
        multi_gpu = False
        full_len = 0
        # print('calculating total steps')
        # for i in tqdm(range(num_pieces)):
        #     with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
        #         full_len += len([int(item) for item in f.read().strip().split()])
        # total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation)
        # print('total steps = {}'.format(total_steps))

        optimizer = transformers.AdamW(model.parameters(),
                                       lr=lr,
                                       correct_bias=True)
        # scheduler = transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps)
        # if fp16:
        #     try:
        #         from apex import amp
        #     except ImportError:
        #         raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        #     model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

        # if torch.cuda.device_count() > 1:
        #     print("Let's use", torch.cuda.device_count(), "GPUs!")
        #     model = DataParallel(model)
        #     multi_gpu = True
        if xm.is_master_ordinal():
            print('starting training')

        doc_size = 10
        raw_data_batch_len = len(raw_data_files) // doc_size
        for epoch in range(epochs):
            if xm.is_master_ordinal():
                print('epoch {}'.format(epoch + 1))
                now = datetime.now()
                print('time: {}'.format(now))
            for batch_len in range(raw_data_batch_len):
                train_dataset = TextDataset(
                    raw_data_files[batch_len * doc_size:(batch_len + 1) *
                                   doc_size], tokenized_data_path,
                    full_tokenizer, n_ctx)

                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=True)

                # Creates dataloaders, which load data in batches
                # Note: test loader is not shuffled or sampled
                train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=batch_size,
                    sampler=train_sampler,
                    num_workers=8,
                    drop_last=True)

                # tokens = get_tokenization(raw_data_file, tokenized_data_path, full_tokenizer)
                # if tokens is None:
                #     continue
                # start_point = 0
                # samples = []
                # while start_point < len(tokens) - n_ctx:
                #     samples.append(tokens[start_point: start_point + n_ctx])
                #     start_point += stride
                # if start_point < len(tokens):
                #     samples.append(tokens[len(tokens) - n_ctx:])
                # random.shuffle(samples)
                para_train_loader = pl.ParallelLoader(
                    train_loader, [device]).per_device_loader(device)
                running_loss = 0
                for step, batch_inputs in enumerate(para_train_loader):

                    # for step in range(len(samples) // batch_size):

                    #  prepare data
                    # batch = samples[step * batch_size: (step + 1) * batch_size]
                    # batch_labels = []
                    # batch_inputs = []
                    # for ids in batch:
                    #     int_ids_for_labels = [int(x) for x in ids]
                    #     int_ids_for_inputs = [int(x) for x in ids]
                    #     batch_labels.append(int_ids_for_labels)
                    #     batch_inputs.append(int_ids_for_inputs)
                    # print(batch_inputs)
                    batch_inputs = batch_inputs.to(device)

                    # print(batch_labels.size(), batch_inputs.size())
                    #  forward pass
                    outputs = model.forward(input_ids=batch_inputs,
                                            labels=batch_inputs)
                    loss, logits = outputs[:2]

                    #  get loss
                    # if multi_gpu:
                    #     loss = loss.mean()
                    # if gradient_accumulation > 1:
                    #     loss = loss / gradient_accumulation

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   max_grad_norm)

                    xm.optimizer_step(optimizer)

                    # if (step + 1) % gradient_accumulation == 0:
                    #     running_loss += loss.item()
                    # optimizer.step()
                    # xm.optimizer_step(optimizer)
                    # optimizer.zero_grad()
                    # scheduler.step()
                    if xm.is_master_ordinal():
                        if (step + 1) % log_step == 0:
                            print(
                                'now time: {}:{}. Step {}/{} of piece {}/{} epoch {}, loss {}'
                                .format(datetime.now().hour,
                                        datetime.now().minute, (step + 1),
                                        len(para_train_loader), batch_len + 1,
                                        raw_data_batch_len, epoch + 1,
                                        running_loss / log_step))
                            running_loss = 0
                        else:
                            running_loss += loss.item()
                xm.save(model.state_dict(), output_dir + 'final_model')

                if xm.is_master_ordinal():
                    gc.collect()
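
`train_model(index)` above has the per-process signature expected by torch_xla's multiprocessing launcher. A hedged sketch of how such a function is typically launched (the `nprocs` value and `start_method` are assumptions, not taken from the original):

import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    # Spawn one process per TPU core; each process receives its ordinal
    # as the first positional argument (the `index` parameter above).
    xmp.spawn(train_model, args=(), nprocs=8, start_method='fork')
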
Example no. 19
0
    def _evaluate(self,
                  model: LightningModule,
                  dataloaders,
                  max_batches: int,
                  test_mode: bool = False):
        """Run evaluation code.

        Args:
            model: PT model
            dataloaders: list of PT dataloaders
            max_batches: Scalar
            test_mode:
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        outputs = []

        # run validation
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []

            # on TPU we have to wrap it under the ParallelLoader
            if self.use_tpu:
                device = xm.xla_device(self.tpu_id)
                dataloader = xla_pl.ParallelLoader(dataloader, [device])
                dataloader = dataloader.per_device_loader(device)

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= max_batches:
                    break

                # callbacks
                if test_mode:
                    self.on_test_batch_start()
                else:
                    self.on_validation_batch_start()

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                if self.use_amp and self.use_native_amp:
                    with torch.cuda.amp.autocast():
                        output = self.evaluation_forward(
                            model, batch, batch_idx, dataloader_idx, test_mode)
                else:
                    output = self.evaluation_forward(model, batch, batch_idx,
                                                     dataloader_idx, test_mode)

                # on dp / ddp2 might still want to do something with the batch parts
                if test_mode:
                    if self.is_overridden('test_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('test_step_end'):
                            output = model_ref.test_step_end(output)
                    self.on_test_batch_end()
                else:
                    if self.is_overridden('validation_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('validation_step_end'):
                            output = model_ref.validation_step_end(output)
                    self.on_validation_batch_end()

                # track outputs for collation
                dl_outputs.append(output)

            outputs.append(dl_outputs)

        eval_results = {}

        # with a single dataloader don't pass an array
        if len(dataloaders) == 1:
            outputs = outputs[0]

        # give model a chance to do something with the outputs (and method defined)
        if isinstance(
                model,
            (LightningDistributedDataParallel, LightningDataParallel)):
            model = model.module

        if test_mode:
            if self.is_overridden('test_end', model=model):
                # TODO: remove in v1.0.0
                eval_results = model.test_end(outputs)
                rank_zero_warn(
                    'Method `test_end` was deprecated in v0.7 and will be removed v1.0.'
                    ' Use `test_epoch_end` instead.', DeprecationWarning)

            elif self.is_overridden('test_epoch_end', model=model):
                eval_results = model.test_epoch_end(outputs)

        else:
            if self.is_overridden('validation_end', model=model):
                # TODO: remove in v1.0.0
                eval_results = model.validation_end(outputs)
                rank_zero_warn(
                    'Method `validation_end` was deprecated in v0.7 and will be removed v1.0.'
                    ' Use `validation_epoch_end` instead.', DeprecationWarning)

            elif self.is_overridden('validation_epoch_end', model=model):
                eval_results = model.validation_epoch_end(outputs)

        # enable train mode again
        model.train()

        # re-enable gradients after evaluation
        torch.set_grad_enabled(True)

        return eval_results
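
`_evaluate` toggles gradient tracking globally with `torch.set_grad_enabled(False)` / `(True)` around the loop. The same effect can be expressed with a context manager; a small, self-contained illustration (not part of the original trainer):

import torch

def run_eval_batches(model, dataloader):
    # torch.no_grad() is the context-manager equivalent of the explicit
    # set_grad_enabled(False)/(True) toggling used in _evaluate above.
    model.eval()
    outputs = []
    with torch.no_grad():
        for batch in dataloader:
            outputs.append(model(batch))
    model.train()
    return outputs
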
Example no. 20
0
def train(args, model, tokenizer):
    """ Train the model """
    if xm.is_master_ordinal():
        tb_writer = SummaryWriterP(args.output_dir)

    def summary_write(*args, **kwargs):
        if xm.is_master_ordinal():
            tb_writer.add_scalar(*args, **kwargs)

    args.train_batch_size = args.per_gpu_train_batch_size  #* max(1, args.n_gpu)

    train_dataloader = build_dataloader(args, tokenizer)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    # Scale learning rate to num cores
    #args.learning_rate = args.learning_rate * xm.xrt_world_size()
    if args.sgd:
        optimizer = SGD(optimizer_grouped_parameters, lr=args.learning_rate)
    else:
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    warmup_steps = args.warmup_samples // (args.train_batch_size *
                                           xm.xrt_world_size())
    if args.lr_decay:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=warmup_steps,
                                         t_total=t_total)
    elif args.lr_cosine:
        scheduler = WarmupCosineWithHardRestartsSchedule(
            optimizer,
            warmup_steps=warmup_steps,
            t_total=t_total,
            cycles=args.num_train_epochs)
    else:
        scheduler = WarmupZeroSchedule(optimizer, warmup_steps=warmup_steps)

    # Train!
    tracker = xm.RateTracker()
    log_info("***** Running training *****")
    log_info("  Num Epochs = %d", args.num_train_epochs)
    log_info("  Instantaneous batch size per GPU = %d",
             args.per_gpu_train_batch_size)
    log_info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (xm.xrt_world_size() if args.local_rank != -1 else 1))
    log_info("  Gradient Accumulation steps = %d",
             args.gradient_accumulation_steps)
    log_info("  Total optimization steps = %d", t_total)

    try:
        with open(os.path.join(args.model_name_or_path, 'step.txt'), 'r') as c:
            global_step = int(c.readline())
    except OSError as e:
        global_step = 0

    moving_loss = MovingLoss(10000 // args.logging_steps)

    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=not xm.is_master_ordinal())
    try:
        for epoch in train_iterator:
            p_train_dataloader = pl.ParallelLoader(train_dataloader,
                                                   [args.device])
            epoch_iterator = tqdm(p_train_dataloader.per_device_loader(
                args.device),
                                  total=len(train_dataloader),
                                  desc="Iteration",
                                  disable=not xm.is_master_ordinal())

            model.train()
            for step, batch in enumerate(epoch_iterator):
                optimizer.zero_grad()
                inputs, labels = mask_tokens(
                    batch, tokenizer, args) if args.mlm else (batch, batch)
                outputs = model(
                    inputs, masked_lm_labels=labels) if args.mlm else model(
                        inputs, labels=labels)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                    xm.optimizer_step(optimizer, barrier=True)
                    scheduler.step()
                    global_step += 1
                    tracker.add(args.train_batch_size)

                    if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        ls = loss.item(
                        )  # weird: if you call loss.item() in only one process, the whole thing hangs, so call it in every process and log from one.
                        moving_loss.add(ls)
                        summary_write('lr',
                                      scheduler.get_last_lr()[0], global_step)
                        epoch_iterator.set_postfix(
                            MovingLoss=f'{moving_loss.loss:.2f}',
                            Perplexity=
                            f'{torch.exp(torch.tensor(moving_loss.loss)):.2f}')

                    if args.save_steps > 0 and global_step % args.save_steps == 0:
                        save_state(args, model, tokenizer, global_step)

                if step >= 2:  # TPU seems to like consistent epoch length
                    if xm.is_master_ordinal():
                        print(met.metrics_report())
                    exit(0)

                #    epoch_iterator.close()
                #    break

                if args.max_steps > 0 and step > args.max_steps:
                    epoch_iterator.close()
                    break

            # evaluate once in an epoch
            if args.evaluate_during_training:
                results = evaluate(args, model, tokenizer,
                                   f"checkpoint-{global_step}")
                log_info(f"Eval {results}")
                for key, value in results.items():
                    summary_write("eval_{}".format(key), value, global_step)

            # reload dataset every args.reload_data_file epochs
            if args.reload_data_file and (epoch +
                                          1) % args.reload_data_file == 0:
                train_dataloader = build_dataloader(args, tokenizer)

            # that's very slow on TPU
            #print_sample(model, tokenizer, args.device, args)

    except (KeyboardInterrupt, SystemExit):
        save_state(args, model, tokenizer, global_step)
        raise

    save_state(args, model, tokenizer, global_step)

    return global_step, moving_loss.loss
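
The loop above relies on a `MovingLoss(window)` helper with an `add()` method and a `loss` attribute that is never shown. A purely illustrative sketch of such a windowed moving-average tracker (the real helper may differ):

from collections import deque

class MovingLoss:
    """Track the mean of the most recent `window` loss values."""

    def __init__(self, window):
        self.values = deque(maxlen=max(1, int(window)))

    def add(self, value):
        self.values.append(value)

    @property
    def loss(self):
        return sum(self.values) / len(self.values) if self.values else 0.0
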
Example no. 21
0
def train_fn(df):
    size = 1; torch.manual_seed(42)

    df = shuffle(df)
    split = np.int32(SPLIT*len(df))
    val_df, train_df = df[split:], df[:split]

    val_df = val_df.reset_index(drop=True)
    val_set = QuoraDataset(val_df, tokenizer)
    val_sampler = DistributedSampler(val_set, num_replicas=8,
                                     rank=xm.get_ordinal(), shuffle=True)

    train_df = train_df.reset_index(drop=True)
    train_set = QuoraDataset(train_df, tokenizer)
    train_sampler = DistributedSampler(train_set, num_replicas=8,
                                       rank=xm.get_ordinal(), shuffle=True)
    
    val_loader = DataLoader(val_set, VAL_BATCH_SIZE,
                            sampler=val_sampler, num_workers=0, drop_last=True)

    train_loader = DataLoader(train_set, BATCH_SIZE,
                              sampler=train_sampler, num_workers=0, drop_last=True)

    device = xm.xla_device()
    network = Roberta().to(device)
    optimizer = Adam([{'params': network.roberta.parameters(), 'lr': LR[0]*size},
                      {'params': network.dense_output.parameters(), 'lr': LR[1]*size}])

    val_losses, val_f1s = [], []
    train_losses, train_f1s = [], []
    
    start = time.time()
    xm.master_print("STARTING TRAINING ...\n")

    for epoch in range(EPOCHS):

        batch = 1
        network.train()
        fonts = (fg(48), attr('reset'))
        xm.master_print(("EPOCH %s" + str(epoch+1) + "%s") % fonts)

        val_parallel = pl.ParallelLoader(val_loader, [device]).per_device_loader(device)
        train_parallel = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        
        for train_batch in train_parallel:
            train_targ, train_in, train_att = train_batch
            
            network = network.to(device)
            train_in = train_in.to(device)
            train_att = train_att.to(device)
            train_targ = train_targ.to(device)

            train_preds = network.forward(train_in, train_att)
            train_loss = bce(train_preds, train_targ)/len(train_preds)
            train_f1 = f1_score(train_preds, train_targ.squeeze(dim=1))

            optimizer.zero_grad()
            train_loss.backward()
            xm.optimizer_step(optimizer)
            
            end = time.time()
            batch = batch + 1
            is_print = batch % 10 == 1
            f1 = np.round(train_f1.item(), 3)
            if is_print: print_metric(f1, batch, None, start, end, metric="F1", typ="Train")

        val_loss, val_f1, val_points = 0, 0, 0

        network.eval()
        with torch.no_grad():
            for val_batch in val_parallel:
                targ, val_in, val_att = val_batch

                targ = targ.to(device)
                val_in = val_in.to(device)
                val_att = val_att.to(device)
                network = network.to(device)
                pred = network.forward(val_in, val_att)

                val_points += len(targ)
                val_loss += bce(pred, targ).item()
                val_f1 += f1_score(pred, targ.squeeze(dim=1)).item()*len(pred)
        
        end = time.time()
        val_f1 /= val_points
        val_loss /= val_points
        f1 = xm.mesh_reduce('f1', val_f1, lambda x: sum(x)/len(x))
        loss = xm.mesh_reduce('loss', val_loss, lambda x: sum(x)/len(x))
        print_metric(np.round(f1, 3), None, epoch, start, end, metric="F1", typ="Val")
    
        xm.master_print("")
        val_f1s.append(f1); train_f1s.append(train_f1.item())
        val_losses.append(loss); train_losses.append(train_loss.item())

    xm.master_print("ENDING TRAINING ...")
    xm.save(network.state_dict(), MODEL_SAVE_PATH); del network; gc.collect()
    
    metric_lists = [val_losses, train_losses, val_f1s, train_f1s]
    metric_names = ['val_loss_', 'train_loss_', 'val_f1_', 'train_f1_']
    
    for i, metric_list in enumerate(metric_lists):
        for j, metric_value in enumerate(metric_list):
            torch.save(metric_value, metric_names[i] + str(j) + '.pt')
Example no. 22
0
    def evaluate(self, model, dataloaders, max_batches, test=False):
        """Run evaluation code.

        :param model: PT model
        :param dataloaders: list of PT dataloaders
        :param max_batches: Scalar
        :param test: boolean
        :return:
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        outputs = []

        # run validation
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []

            # on TPU we have to wrap it under the ParallelLoader
            if self.use_tpu:
                device = xm.xla_device()
                dataloader = xla_pl.ParallelLoader(dataloader, [device])
                dataloader = dataloader.per_device_loader(device)

            for batch_idx, batch in enumerate(dataloader):

                if batch is None:  # pragma: no cover
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= max_batches:
                    break

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                output = self.evaluation_forward(model, batch, batch_idx,
                                                 dataloader_idx, test)

                # track outputs for collation
                dl_outputs.append(output)

                # batch done
                if batch_idx % self.progress_bar_refresh_rate == 0:
                    if test:
                        self.test_progress_bar.update(
                            self.progress_bar_refresh_rate)
                    else:
                        self.val_progress_bar.update(
                            self.progress_bar_refresh_rate)
                        self.main_progress_bar.update(
                            self.progress_bar_refresh_rate)
            outputs.append(dl_outputs)

        eval_results = {}

        # with a single dataloader don't pass an array
        if len(dataloaders) == 1:
            outputs = outputs[0]

        # give model a chance to do something with the outputs (and method defined)
        model = self.get_model()
        if test and self.is_overriden('test_end'):
            eval_results = model.test_end(outputs)
        elif self.is_overriden('validation_end'):
            eval_results = model.validation_end(outputs)

        # enable train mode again
        model.train()

        # re-enable gradients after evaluation
        torch.set_grad_enabled(True)

        return eval_results
Example no. 23
0
    def run_training_epoch(self):

        # Epoch begin callbacks
        self.on_epoch_start()

        # before epoch hook
        if self.is_function_implemented('on_epoch_start'):
            model = self.get_model()
            with self.profiler.profile('on_epoch_start'):
                model.on_epoch_start()

        # reset train dataloader
        if self.reload_dataloaders_every_epoch:
            self.reset_train_dataloader(self.get_model())

        # track local dataloader so TPU can wrap each epoch
        train_dataloader = self.train_dataloader

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device()
            train_dataloader = xla_pl.ParallelLoader(train_dataloader,
                                                     [device])
            train_dataloader = train_dataloader.per_device_loader(device)

        # run epoch
        for batch_idx, batch in self.profiler.profile_iterable(
                enumerate(train_dataloader), "get_train_batch"):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx

            model = self.get_model()
            model.global_step = self.global_step

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            output = self.run_training_batch(batch, batch_idx)
            batch_result, grad_norm_dic, batch_step_metrics = output

            # when returning -1 from train_step, we end epoch early
            early_stop_epoch = batch_result == -1

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
            can_check_epoch = (self.current_epoch +
                               1) % self.check_val_every_n_epoch == 0
            should_check_val = not self.disable_validation and can_check_epoch
            should_check_val = should_check_val and (is_val_check_batch
                                                     or early_stop_epoch)

            # fast_dev_run always forces val checking after train batch
            if self.fast_dev_run or should_check_val:
                self.run_evaluation(test_mode=self.testing)

                if self.enable_early_stop:
                    self.early_stop_callback.check_metrics(
                        self.callback_metrics)

            # when logs should be saved
            should_save_log = (
                batch_idx +
                1) % self.log_save_interval == 0 or early_stop_epoch
            if should_save_log or self.fast_dev_run:
                if self.proc_rank == 0 and self.logger is not None:
                    self.logger.save()

            # when metrics should be logged
            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
            if should_log_metrics or self.fast_dev_run:
                # logs user requested information to logger
                self.log_metrics(batch_step_metrics, grad_norm_dic)

            # progress global step according to grads progress
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                self.global_step += 1
            self.total_batch_idx += 1

            # max steps reached, end training
            if self.max_steps is not None and self.max_steps == self.global_step:
                break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if early_stop_epoch or self.fast_dev_run:
                break

        # epoch end hook
        if self.is_function_implemented('on_epoch_end'):
            model = self.get_model()
            with self.profiler.profile('on_epoch_end'):
                model.on_epoch_end()

        # Epoch end callbacks
        self.on_epoch_end()
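
Note how `global_step` in the loop above only advances once every `accumulate_grad_batches` batches, matching the optimizer-step cadence. A tiny worked illustration of that counter logic (standalone, not part of the trainer):

accumulate_grad_batches = 4
global_step = 0
for batch_idx in range(8):
    # mirrors `(self.batch_idx + 1) % self.accumulate_grad_batches == 0`
    if (batch_idx + 1) % accumulate_grad_batches == 0:
        global_step += 1
# after 8 batches, global_step == 2 (incremented on batch_idx 3 and 7)
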
Example no. 24
0
    def run():

        torch.manual_seed(seed)

        device = xm.xla_device()
        model = MX.to(device)

        # DataLoaders
        train_dataset = TweetDataset(args=args,
                                     df=train_df,
                                     mode="train",
                                     fold=args.fold_index,
                                     tokenizer=tokenizer)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   sampler=train_sampler,
                                                   drop_last=False,
                                                   num_workers=2)

        valid_dataset = TweetDataset(args=args,
                                     df=train_df,
                                     mode="valid",
                                     fold=args.fold_index,
                                     tokenizer=tokenizer)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=args.batch_size,
                                  sampler=valid_sampler,
                                  num_workers=1,
                                  drop_last=False)

        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.001
            },
            {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            },
        ]

        num_train_steps = int(num_train_dpoints / args.batch_size /
                              xm.xrt_world_size() * args.epochs)

        optimizer = AdamW(optimizer_parameters,
                          lr=args.learning_rate * xm.xrt_world_size())

        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

        xm.master_print("Training is Starting ...... ")
        best_jac = 0
        #early_stopping = utils.EarlyStopping(patience=2, mode="max", verbose=True)

        for epoch in range(args.epochs):
            para_loader = pl.ParallelLoader(train_loader, [device])
            train_loss = train(args, para_loader.per_device_loader(device),
                               model, device, optimizer, scheduler, epoch, f)

            para_loader = pl.ParallelLoader(valid_loader, [device])
            valid_jac = valid(args, para_loader.per_device_loader(device),
                              model, device, tokenizer, epoch, f)

            jac = xm.mesh_reduce("jac_reduce", valid_jac, reduce_fn)
            xm.master_print(f"**** Epoch {epoch+1} **==>** Jaccard = {jac}")

            log_ = f"**** Epoch {epoch+1} **==>** Jaccard = {jac}"

            f.write(log_ + "\n\n")

            if jac > best_jac:
                xm.master_print("**** Model Improved !!!! Saving Model")
                xm.save(
                    model.state_dict(),
                    os.path.join(args.save_path, f"fold_{args.fold_index}"))
                best_jac = jac
Example no. 25
0
def train_tpu():
    torch.manual_seed(1)

    def get_dataset():
        fold_number = 0
        train_ = pd.read_csv(args.train_fold)
        train = ShopeeDataset(
            train_[train_['fold'] != fold_number].reset_index(drop=True))
        test = ShopeeDataset(
            train_[train_['fold'] != fold_number].reset_index(drop=True),
            transform=args.test_args)
        return train, test

    # Using the serial executor avoids having multiple processes
    # download the same data.
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               sampler=train_sampler,
                                               num_workers=args.num_workers,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              drop_last=True)

    # Scale learning rate to num cores
    learning_rate = 1e-5 * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, label) in enumerate(loader):
            optimizer.zero_grad()
            output = model(image=data,
                           label=label,
                           get_embedding=args.get_embeddings)
            loss = loss_fn(output, label)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(args.batch_size)
            if x % 20 == 0:
                print(
                    '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                    .format(xm.get_ordinal(), x, loss.item(), tracker.rate(),
                            tracker.global_rate(), time.asctime()),
                    flush=True)

    def test_loop_fn(loader):
        model.eval()
        for x, (data, label) in enumerate(loader):
            output = model(image=data,
                           label=label,
                           get_embedding=args.get_embeddings)
            loss = loss_fn(output, label)
            if x % 20 == 0:
                print('[xla:{}]({}) Loss={:.5f}'.format(
                    xm.get_ordinal(), x, loss.item()),
                      flush=True)

    for epoch in range(1, args.n_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        test_loop_fn(para_loader.per_device_loader(device))
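
`train_tpu` (and `train_mnist` in Example no. 29 below) reads the module-level globals `SERIAL_EXEC` and `WRAPPED_MODEL`, which are not shown. A sketch of how these are commonly defined with torch_xla's multiprocessing helpers; `build_model()` is a placeholder, not a name from the original:

import torch_xla.distributed.xla_multiprocessing as xmp

# Run dataset preparation once per host instead of once per core.
SERIAL_EXEC = xmp.MpSerialExecutor()

# Keep a single copy of the model weights in host memory; each process
# moves it to its own XLA device via WRAPPED_MODEL.to(device).
WRAPPED_MODEL = xmp.MpModelWrapper(build_model())
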
Example no. 26
0
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)

        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
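
With `FLAGS.fake_data`, the loaders above are replaced by `xu.SampleGenerator`, which simply replays one constant (data, target) pair for a fixed number of steps so the input pipeline can be exercised without ImageNet on disk. A rough Python equivalent of that behaviour (an approximation, not the library implementation):

def fake_sample_generator(sample, sample_count):
    # Yield the same (data, target) tuple `sample_count` times,
    # mimicking how xu.SampleGenerator is used above.
    for _ in range(sample_count):
        yield sample
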
Example no. 27
0
    def prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(
                dataloader,
                description,
                prediction_loss_only=prediction_loss_only)

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(
                dataloader,
                [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        for inputs in tqdm(dataloader, desc=description):
            loss, logits, labels = self.prediction_step(
                model, inputs, prediction_loss_only)
            if loss is not None:
                eval_losses.append(loss)
            if logits is not None:
                preds = logits if preds is None else torch.cat(
                    (preds, logits), dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else torch.cat(
                    (label_ids, labels), dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(
                    preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(
                    label_ids,
                    num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids,
                                           torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)
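
In the TPU branch above, the per-core `preds` tensors are gathered with `xm.mesh_reduce("eval_preds", preds, torch.cat)`, i.e. the reduce function concatenates rather than averages. A tiny standalone illustration of the difference, using hypothetical tensors:

import torch

per_core_preds = [torch.ones(2, 3), torch.zeros(2, 3)]  # what each core contributed
gathered = torch.cat(per_core_preds, dim=0)             # shape (4, 3): concatenation, as in prediction_loop
mean_loss = sum(t.mean() for t in per_core_preds) / len(per_core_preds)  # averaging, as in the loss examples
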
Example no. 28
0
def run():
    '''
    Entire training loop
        - Create DataLoaders
        - Define Training Configuration
        - Launch Training Loop
    '''

    # Num of available TPU cores
    if config.TPUs:
        n_TPUs = xm.xrt_world_size()
        DEVICE = xm.xla_device()
    else:
        DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(DEVICE)
    
    # Read Data
    
    # df1 = pd.read_csv('data/jigsaw-toxic-comment-train.csv', usecols=['comment_text', 'toxic'])
    # df2 = pd.read_csv('data/jigsaw-unintended-bias-train.csv', usecols=['comment_text', 'toxic'], engine='python') # don't know why it was breaking with default C parser
    # df_train = df1 # pd.concat([df1,df2], axis=0).reset_index(drop=True)
    # df_valid = pd.read_csv('data/validation.csv')
    
    # Subsample
    df_train = pd.read_csv('data/jigsaw-toxic-comment-train-small.csv', usecols=['comment_text', 'toxic'])
    df_valid = pd.read_csv('data/validation-small.csv', usecols=['comment_text', 'toxic']) 

    # Preprocess
    
    train_dataset = dataset.BERTDataset(
        comment=df_train.comment_text.values,
        target=df_train.toxic.values
    )

    valid_dataset = dataset.BERTDataset(
        comment=df_valid.comment_text.values,
        target=df_valid.toxic.values
    )

    drop_last=False
    train_sampler, valid_sampler = None, None
    if config.TPUs:
        drop_last=True
        train_sampler = DistributedSampler(
            train_dataset, 
            num_replicas=n_TPUs,
            rank=xm.get_ordinal(),
            shuffle=True
        )
        valid_sampler = DistributedSampler(
            valid_dataset, 
            num_replicas=n_TPUs,
            rank=xm.get_ordinal(),
            shuffle=True
        )


    # Create Data Loaders

    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        num_workers=4,
        drop_last=drop_last,
        sampler=train_sampler
    )


    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=1,
        drop_last=drop_last,
        sampler=valid_sampler
    )

    # Machine Configuration

    if config.MODEL == 'bert':
        model = BERTBaseUncased()
    elif config.MODEL == 'distil-bert':
        model = DistilBERTBaseUncased()
    else:
        print('Model chosen in config not valid')
        exit()
    model.to(device)
    
    # Optimizer Configuration 

    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]

    lr = config.LR
    num_train_steps = int(len(df_train) / config.TRAIN_BATCH_SIZE * config.EPOCHS)
    # TODO: why does the LR increase for distributed training?
    if config.TPUs:
        num_train_steps /= n_TPUs
        lr *= n_TPUs
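        # With n_TPUs cores each consuming its own shard of the data, the
        # effective global batch size grows by a factor of n_TPUs, so the
        # common linear-scaling heuristic raises the learning rate
        # accordingly, and each core only needs 1/n_TPUs of the optimizer steps.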

    optimizer = AdamW(optimizer_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    if not config.TPUs:
        if N_GPU > 1:
            model = nn.DataParallel(model)
    
    # Training loop

    best_score = 0
    
    for epoch in range(config.EPOCHS):
    
        if config.TPUs:
            train_loader = pl.ParallelLoader(train_data_loader, [device])
            valid_loader = pl.ParallelLoader(valid_data_loader, [device])
            train_fn(train_loader.per_device_loader(device), model, optimizer, device, scheduler)
            outputs, targets = eval_fn(valid_loader.per_device_loader(device), model, device)

        else:
            train_fn(train_data_loader, model, optimizer, device, scheduler)
            outputs, targets = eval_fn(valid_data_loader, model, device)
        
        targets = np.array(targets) >= 0.5 # TODO: why ?
        auc_score = metrics.roc_auc_score(targets, outputs)
            
        # Save if best
        print(f"AUC Score = {auc_score}")
        if auc_score > best_score:
            if not config.TPUs:
                torch.save(model.state_dict(), config.MODEL_PATH)
            else:
                xm.save(model.state_dict(), config.MODEL_PATH)
            best_score = auc_score
Example no. 29
0
def train_mnist():
    torch.manual_seed(1)

    """
    When using a TPU, handle the dataset like this:
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    """
    def get_dataset():
        norm = transforms.Normalize((0.1307,), (0.3081,))
        train_dataset = datasets.MNIST(
            FLAGS['datadir'],
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), norm]))
        test_dataset = datasets.MNIST(
            FLAGS['datadir'],
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), norm]))

        return train_dataset, test_dataset

    # Using the serial executor avoids having multiple processes
    # download the same data.
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS['batch_size'],
        sampler=train_sampler,
        num_workers=FLAGS['num_workers'],
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS['batch_size'],
        shuffle=False,
        num_workers=FLAGS['num_workers'],
        drop_last=True)

    # Scale learning rate to world size
    lr = FLAGS['learning_rate'] * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    """
    When using a TPU, the device is set up like this:
    device = xm.xla_device()
    model = xmp.MpModelWrapper(MNIST()).to(device)
    """
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            # when using a TPU, step the optimizer with xm.optimizer_step(optimizer)
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS['batch_size'])
            if x % FLAGS['log_steps'] == 0:
                print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                    xm.get_ordinal(), x, loss.item(), tracker.rate(),
                    tracker.global_rate(), time.asctime()), flush=True)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        print('[xla:{}] Accuracy={:.2f}%'.format(
            xm.get_ordinal(), accuracy), flush=True)
        return accuracy, data, pred, target

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, FLAGS['num_epochs'] + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))
        if FLAGS['metrics_debug']:
            xm.master_print(met.metrics_report(), flush=True)

    return accuracy, data, pred, target
Example no. 30
0
def train(rank, args):
    print('enter train @ %s' % (rank), flush=True)
    args.rank = rank
    torch.manual_seed(42)

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size()

    train_dataset = get_dataset(args)

    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates

    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)

    train_sampler = None
    if args.gpus:
        dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=args.gpus,
                rank=rank,
                shuffle=False)

    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=rank,
                shuffle=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size
        if not hasattr(train_dataset, '__getbatch__') else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)

    eval_loader = None
    if args.eval_dir:

        eval_sampler = None
        if args.gpus:
            dist.init_process_group('nccl',
                                    rank=rank,
                                    world_size=args.world_size)
            if args.gpus > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset,
                    num_replicas=args.gpus,
                    rank=rank,
                    shuffle=False)

        else:
            rank = xm.get_ordinal()
            if xm.xrt_world_size() > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=rank,
                    shuffle=False)

        eval_dataset = get_eval_dataset(args)
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.batch_size
            if not hasattr(eval_dataset, '__getbatch__') else None,
            sampler=eval_sampler,
            pin_memory=True,
            shuffle=False,
            num_workers=args.num_workers)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args)

        model.cuda(rank)

        device = torch.device('cuda:' + str(rank))

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),

            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay)

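        # amp opt_level O1 patches eligible ops to run in FP16 and enables
        # dynamic loss scaling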
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader

    else:
        assert tpu_enabled
        device = xm.xla_device()

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################

        model = get_model(args)

        ##########################
        ##
        ##  For shared parameters, TPU requires modules to be tied after .to(device),
        ##  so we record the shared parameters first and re-tie them afterwards
        ##
        ##########################

        shared_parameters = {
            e[0]: e[1:]
            for e in _catalog_shared_params(model)
        }

        model.to(device)

        do_share_parameters_again(model, shared_parameters, log=rank == 0)

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),

            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay)

        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)

        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

        # tracker = xm.RateTracker()
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
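        # strip DistributedDataParallel's 'module.' prefix and remap keys ending in
        # 'position_ids' to the corresponding 'position_embeddings' keys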
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            traceback.print_exc()
            model.load_state_dict(states, strict=False)

    model.train()

    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ##  Init LR Scheduler
    ##
    ##########################

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_updates,
        num_training_steps=args.total_num_updates,
    )

    step_i = 0

    err = None
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates)
        while step_i < args.total_num_updates:
            if not args.gpus:
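                # ParallelLoader stages batches onto the XLA device asynchronously;
                # a fresh per-device loader is built for every pass over the dataset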
                batches = pl.ParallelLoader(train_loader,
                                            [device]).per_device_loader(device)
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break

                report_step = step_i % args.log_interval == 0

                while True:  # loop only to replay the batch after an apex gradient overflow
                    optimizer.zero_grad()

                    total_loss, log = get_loss(model,
                                               sample,
                                               args=args,
                                               device=device,
                                               gpu=args.gpus,
                                               report=report_step)

                    if args.gpus:
                        default_optimizer_step = optimizer.step

                        with amp.scale_loss(total_loss,
                                            optimizer) as scaled_loss:
                            scaled_loss.backward()

                        # If Amp detects an overflow, it patches optimizer.step.  In other words, if optimizer.step
                        # was left unpatched, there was no overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break

                        # If an overflow was detected, "optimizer.step" is the patched call, which
                        # does nothing but restore optimizer.step to default_optimizer_step.
                        optimizer.step()
                        if rank == 0:
                            print(
                                "Overflowed, reducing loss scale and replaying batch.",
                                flush=True)

                    else:
                        total_loss.backward()
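                        # xm.optimizer_step all-reduces gradients across the TPU
                        # replicas before calling optimizer.step()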
                        xm.optimizer_step(optimizer)
                        xm.mark_step()

                        break

                scheduler.step()

                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss

                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(
                                log, log_formatter))
                    else:
                        xm.add_step_closure(_train_update,
                                            args=(log, log_formatter))

                    # XLA metrics reports are only meaningful on TPU
                    if args.report_metrics and not args.gpus:
                        xm.master_print(met.metrics_report())

                if rank == 0:
                    pbar.update(1)

        if eval_loader is not None:
            model.eval()
            if args.gpus:
                batches = eval_loader
            else:
                batches = pl.ParallelLoader(eval_loader,
                                            [device]).per_device_loader(device)
            with torch.no_grad():
                record = OrderedDict()

                for sample in batches:
                    evaluate(model,
                             sample,
                             args=args,
                             device=device,
                             record=record,
                             gpu=args.gpus,
                             report=report_step)

                post_evaluate(record, args=args)

            import json
            print('', flush=True)
            print(json.dumps(record), flush=True)
            print('', flush=True)

    except Exception as _err:
        err = _err
    finally:
        save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if args.gpus:
            # only the rank-0 GPU worker writes the checkpoint
            if rank == 0:
                torch.save(model.state_dict(), save_fn)
        else:
            # xm.save rendezvous the TPU workers and writes from the master ordinal only
            xm.save(model.state_dict(), save_fn)
        if err:
            raise err
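
For reference, a minimal sketch of how a train(rank, args) entry point like the one above is typically launched on either back end. parse_args is a hypothetical helper that must provide the fields used above (gpus, world_size, batch_size, ...), the process counts are placeholders, and the GPU path assumes MASTER_ADDR/MASTER_PORT are set in the environment before dist.init_process_group runs.

if __name__ == '__main__':
    args = parse_args()  # hypothetical argument parser
    if args.gpus:
        import torch.multiprocessing as mp
        # one process per GPU; mp.spawn passes the process index as the first argument
        mp.spawn(train, args=(args,), nprocs=args.gpus)
    else:
        import torch_xla.distributed.xla_multiprocessing as xmp
        # one process per TPU core; xmp.spawn also passes the process index first
        xmp.spawn(train, args=(args,), nprocs=8)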