Example #1
    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only: bool = False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] = None,
    ):
        """
        Trainer is a simple but feature-complete training and eval loop for PyTorch,
        optimized for Transformers.

        Args:
            prediction_loss_only:
                (Optional) In evaluation and prediction, only return the loss.
        """
        self.model = model.to(args.device)
        self.args = args
        self.data_collator = data_collator if data_collator is not None else default_data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif is_tensorboard_available() and self.is_world_master():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        else:
            # Always define the attribute so later `if self.tb_writer:` checks cannot fail.
            self.tb_writer = None
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self._setup_wandb()
        else:
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and won't need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                "The `data_collator` should now be a simple callable (function, class with `__call__`); classes "
                "with a `collate_batch` method are deprecated and won't be supported in a future version.",
                FutureWarning,
            )
Example #2
 def _log(self,
          logs: Dict[str, float],
          iterator: Optional[tqdm] = None) -> None:
     if self.epoch is not None:
         logs["epoch"] = self.epoch
     if self.global_step is None:
         # when logging evaluation metrics without training
         self.global_step = 0
     if self.tb_writer:
         for k, v in logs.items():
             if isinstance(v, (int, float)):
                 self.tb_writer.add_scalar(k, v, self.global_step)
             else:
                 logger.warning(
                     "Trainer is attempting to log a value of "
                     '"%s" of type %s for key "%s" as a scalar. '
                     "This invocation of Tensorboard's writer.add_scalar() "
                     "is incorrect so we dropped this attribute.",
                     v,
                     type(v),
                     k,
                 )
         self.tb_writer.flush()
     if is_wandb_available():
         if self.is_world_master():
             wandb.log(logs, step=self.global_step)
     output = {**logs, **{"step": self.global_step}}
     if iterator is not None:
         # tqdm.write expects a string, so stringify the log dict first.
         iterator.write(str(output))
     else:
         logger.info(output)
Example #3
 def _log(self, logs: Dict[str, float]) -> None:
     if self.tb_writer:
         with self.tb_writer.as_default():
             for k, v in logs.items():
                 tf.summary.scalar(k, v, step=self.global_step)
         self.tb_writer.flush()
     if is_wandb_available():
         wandb.log(logs, step=self.global_step)
     output = {**logs, **{"step": self.global_step}}
     logger.info(output)
Example #4
    def log(self,
            logs: Dict[str, float],
            iterator: Optional[tqdm] = None) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
            iterator (:obj:`tqdm`, `optional`):
                A potential tqdm progress bar to write the logs on.
        """
        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs, iterator=iterator)

        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_master():
                wandb.log(logs, step=self.global_step)
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            # tqdm.write expects a string, so stringify the log dict first.
            iterator.write(str(output))
        else:
            logger.info(output)
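Since the docstring invites subclassing, a minimal sketch of injecting custom behavior could look like the following; the subclass name and the "custom/" prefix are purely illustrative.

class PrefixedLoggingTrainer(Trainer):  # hypothetical subclass
    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        # Tag every metric, then delegate to the default TensorBoard/W&B/console
        # handling implemented in the method above.
        logs = {f"custom/{k}": v for k, v in logs.items()}
        super().log(logs, iterator=iterator)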
Example #5
    def __init__(
        self,
        model: TFPreTrainedModel,
        args: TFTrainingArguments,
        train_dataset: Optional[tf.data.Dataset] = None,
        eval_train_dataset: Optional[tf.data.Dataset] = None,
        eval_dataset: Optional[tf.data.Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only: bool = False,
        tb_writer: Optional[tf.summary.SummaryWriter] = None,
        optimizers: Optional[Tuple[
            tf.keras.optimizers.Optimizer,
            tf.keras.optimizers.schedules.LearningRateSchedule]] = None,
        train_size: Optional[int] = None,
    ):
        self.model = model
        self.args = args
        self.train_dataset = train_dataset
        self.eval_train_dataset = eval_train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        self.gradient_accumulator = GradientAccumulator()
        self.global_step = 0
        self.epoch_logging = 0

        self.num_train_examples = train_size

        if tb_writer is not None:
            self.tb_writer = tb_writer
        else:
            self.tb_writer = tf.summary.create_file_writer(
                self.args.logging_dir)
        if is_wandb_available():
            self._setup_wandb()
        else:
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
Example #6
try:
    from torch.utils.tensorboard import SummaryWriter

    _has_tensorboard = True
except ImportError:
    try:
        from tensorboardX import SummaryWriter

        _has_tensorboard = True
    except ImportError:
        _has_tensorboard = False


def is_tensorboard_available():
    return _has_tensorboard


if is_wandb_available():
    import wandb

logger = logging.getLogger(__name__)


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training wait for the local master to do something first.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()
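This context manager packages the "local master goes first" pattern that Example #7 writes out with explicit barriers. A hedged usage sketch, reusing names from Example #7: non-master ranks block at the first barrier, rank 0 builds the cache inside the block, and once it exits, the remaining ranks run the same body and pick up the freshly written cache.

with torch_distributed_zero_first(training_args.local_rank):
    # Only one process per node does the expensive preprocessing at a time.
    train_dataset, _ = get_dataset(data_args.train_file_path, tokenizer, data_args)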
Example #7
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataTrainingArguments
    training_args: TrainingArguments
    if training_args.fp16:
        try:
            import apex

            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
    # if training_args.do_eval and not training_args.do_train and not data_args.predictions_folder:
    #     raise ValueError("Supply predictions folder destination to save the predictions!")
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )

    logger.debug(model_args)
    logger.debug(training_args)
    logger.debug(data_args)
    # raise NotImplementedError
    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            f"Use --overwrite_output_dir to overcome.")

    # Set seed
    set_seed(training_args.seed)
    if training_args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
    tokenizer = get_tokenizer(model_args.model_name_or_path,
                              do_lower_case=False)
    if data_args.model_parallel == 4:
        model = T5ForConditionalGeneration4WayParallel.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )
    elif data_args.model_parallel == 2:
        model = T5ForConditionalGeneration2WayParallel.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )
    elif data_args.model_parallel is None:
        model = T5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )
    else:
        raise ValueError(
            f"Can only have no, 2way or 4way model parallelism! (expected: {data_args.model_parallel})"
        )
    if training_args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()
    # Get datasets
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        eval_dataset, examples = get_dataset(data_args.eval_file_path,
                                             tokenizer,
                                             data_args,
                                             evaluate=True)
    else:
        eval_dataset, examples = None, None
    # Training
    if training_args.do_train:
        if training_args.local_rank in [-1, 0]:
            train_dataset, _ = get_dataset(data_args.train_file_path,
                                           tokenizer, data_args)
            torch.save(train_dataset, 'features.bin')
        else:
            torch.distributed.barrier()
            train_dataset = None

        if training_args.local_rank == 0:
            torch.distributed.barrier()
        else:
            train_dataset = torch.load('features.bin')
        # Initialize our Trainer
        if data_args.model_parallel:
            trainer = MyTrainer(model=model,
                                args=training_args,
                                train_dataset=train_dataset,
                                eval_dataset=eval_dataset,
                                data_collator=collate_training,
                                prediction_loss_only=True)
            model.set_parallel()
        else:
            trainer = Trainer(model=model,
                              args=training_args,
                              train_dataset=train_dataset,
                              eval_dataset=eval_dataset,
                              data_collator=collate_training,
                              prediction_loss_only=True)
        trainer.train(
            model_path=model_args.model_name_or_path
            if os.path.isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        if training_args.do_train:
            model_path = os.path.basename(training_args.output_dir)
        else:
            model_path = os.path.basename(model_args.model_name_or_path)
        checkpoints = [training_args.output_dir]
        if data_args.eval_all_checkpoints and training_args.do_train:
            logger.info(
                "Loading checkpoints saved during training for evaluation")
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(training_args.output_dir + "/**/" + WEIGHTS_NAME,
                              recursive=True)))
            # logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs

        logger.info(f"Evaluate the following checkpoints: {checkpoints}")
        results = {}

        logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)
        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split("-")[-1]
            if not all(s in string.digits for s in global_step):
                global_step = ''
            # no model parallelism here (didn't check model.generate)
            model = T5ForConditionalGeneration.from_pretrained(checkpoint)
            device = torch.device("cuda" if torch.cuda.is_available()
                                  and not training_args.no_cuda else "cpu")
            model.to(device)
            model_str = f'{model_path}-{global_step}' if global_step else model_path
            # Note that DistributedSampler samples randomly
            click.echo(
                f"Generating predictions for model {click.style(model_str, fg='blue')}, "
                f"running on {click.style(str(training_args.device), fg='green')}"
            )
            predictions = generate_predictions(eval_dataset, examples, model,
                                               tokenizer, training_args)
            final_metric = squad_evaluate(examples, predictions)

            if is_wandb_available():
                if training_args.do_train:
                    step = int(global_step) if global_step else trainer.global_step
                else:
                    step = 0
                # for now WANDB cannot 'log back in time'
                wandb.log(final_metric, step=step)
            print(f"GLOBAL STEP: {global_step}")
            result = dict(
                (k + ("_{}".format(global_step) if global_step else '_final'),
                 v) for k, v in final_metric.items())

            logger.info(f"Result for {model_str}: {result}")
            results.update(result)

        # sort results by best
        checkpoint_scores = {
            c.split('_')[-1]: v
            for c, v in results.items()
            if any(c.endswith(digit)
                   for digit in string.digits) and c.startswith('exact')
        }
        sorted_checkpoint_scores = {
            k: v
            for k, v in sorted(checkpoint_scores.items(),
                               key=lambda k_v: k_v[1],
                               reverse=True)
        }
        best_cp = next((c for c, v in sorted_checkpoint_scores.items()
                        if v > results['exact_final']), None)

        if best_cp:
            click.echo(f"Best checkpoint is: {best_cp}")
            # copy over best results
            best_cp_folder = f'checkpoint-{best_cp}'

            click.echo(
                f"Copying over files: from {os.path.join(training_args.output_dir, best_cp_folder)} "
                f"to {training_args.output_dir}")
            files_to_copy = glob.glob(
                os.path.join(training_args.output_dir, best_cp_folder, '*'))
            for file in files_to_copy:
                shutil.copy(file, training_args.output_dir)
        else:
            click.echo("best checkpoint is the last step...")
        # remove checkpoint folders
        folders_to_remove = [
            p for p in glob.glob(os.path.join(training_args.output_dir, '*'))
            if os.path.isdir(p)
        ]
        click.echo('Folders to remove: ')
        for folder in folders_to_remove:
            click.echo(f"Removing {folder}")
            shutil.rmtree(folder)
        if training_args.do_train:
            logger.info(results)
            write_json(
                results,
                os.path.join(training_args.output_dir, 'dev-results.json'))
        else:
            write_json(
                predictions,
                get_output_predictions_file_name(
                    data_args.eval_file_path, training_args.output_dir,
                    os.path.basename(
                        os.path.normpath(model_args.model_name_or_path))))
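generate_predictions() is project-specific and not shown in these excerpts. Assuming the eval features carry input_ids/attention_mask and the examples expose qas_id, it presumably reduces to batched model.generate calls followed by decoding, roughly as in this hypothetical sketch:

import torch
from torch.utils.data import DataLoader

def generate_predictions_sketch(eval_dataset, examples, model, tokenizer, args):
    """Hypothetical stand-in for generate_predictions(): greedy generation per
    batch, with decoded answers keyed by qas_id so squad_evaluate() can consume
    them. Batch layout and attribute names are assumptions, not the project's code."""
    loader = DataLoader(eval_dataset, batch_size=args.per_device_eval_batch_size)
    model.eval()
    answers = []
    with torch.no_grad():
        for batch in loader:
            output_ids = model.generate(
                input_ids=batch["input_ids"].to(args.device),
                attention_mask=batch["attention_mask"].to(args.device),
                max_length=32,
            )
            answers.extend(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
    return {ex.qas_id: ans for ex, ans in zip(examples, answers)}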