Example #1
0
 def build_with_command(self, ext_builder):
     if CUDA_HOME is None:  # GPU only
         # TODO(guosheng): should we touch a dummy file or add a quick exit
         # method to avoid meaningless process in `load`
         logger.warning(
             "FasterTransformer is not available because CUDA can not be found."
         )
         raise NotImplementedError
     # TODO(guosheng): Multiple -std flags seem to be passed in FasterTransformer,
     # which is not allowed by NVCC. Fix it later.
     self.cmake_args = [f"-DPY_CMD={sys.executable}"]
     # `GetCUDAComputeCapability` is not exposed yet, so detect the CUDA/GPU
     # version in the cmake file instead.
     # self.cmake_args += [f"-DSM={self.sm}"] if self.sm is not None else []
     self.cmake_args += ["-DWITH_GPT=ON"]
     if self.need_parallel:
         self.cmake_args += ["-DWITH_PARALLEL=ON"]
     try:
         super(FasterTransformerExtension,
               self).build_with_command(ext_builder)
         # FasterTransformer cmake file resets `CMAKE_LIBRARY_OUTPUT_DIRECTORY`
         # to `CMAKE_BINARY_DIR/lib`, thus copy the lib back to `build_ext.build_lib`.
         # Maybe move this copy to CMakeList.
         # Use `copy_tree` rather than `copy_file` since the boost lib might be included.
         ext_builder.copy_tree(os.path.join(ext_builder.build_temp, "lib"),
                               ext_builder.build_lib)
         # TODO(guosheng): Maybe we should delete the build dir especially
         # when it is in the dir of paddlenlp package.
         # os.remove(ext_builder.build_temp)
     except Exception as e:
         logger.warning(
             "FasterTransformer is not available due to build errors.")
         raise e
Example #2
0
 def _get_data(self, mode):
     """ Check and download Dataset """
     dl_paths = {}
     version = self.config.get("version", "3.0.0")
     if version not in ["1.0.0", "2.0.0", "3.0.0"]:
         raise ValueError("Unsupported version: %s" % version)
     dl_paths["version"] = version
     default_root = os.path.join(DATA_HOME, self.__class__.__name__)
     for k, v in self.cnn_dailymail.items():
         dir_path = os.path.join(default_root, k)
         if not os.path.exists(dir_path):
             get_path_from_url(v["url"], default_root, v["md5"])
         unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                  .trainer_endpoints[:])
         if ParallelEnv().current_endpoint in unique_endpoints:
             file_num = len(os.listdir(os.path.join(dir_path, "stories")))
             if file_num != v["file_num"]:
                 logger.warning(
                     "Number of %s stories is %d != %d, decompress again." %
                     (k, file_num, v["file_num"]))
                 shutil.rmtree(os.path.join(dir_path, "stories"))
                 _decompress(
                     os.path.join(default_root, os.path.basename(v["url"])))
         dl_paths[k] = dir_path
     filename, url, data_hash = self.SPLITS[mode]
     fullname = os.path.join(default_root, filename)
     if not os.path.exists(fullname) or (data_hash and
                                         not md5file(fullname) == data_hash):
         get_path_from_url(url, default_root, data_hash)
     dl_paths[mode] = fullname
     return dl_paths
Example #3
0
    def _get_data(self, mode, **kwargs):
        """Downloads dataset."""
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash, url, zipfile_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if mode == 'train':
            if not os.path.exists(fullname):
                get_path_from_url(url, default_root, zipfile_hash)
            unique_endpoints = _get_unique_endpoints(
                ParallelEnv().trainer_endpoints[:])
            if ParallelEnv().current_endpoint in unique_endpoints:
                file_num = len(os.listdir(fullname))
                if file_num != len(ALL_LANGUAGES):
                    logger.warning(
                        "Number of train files is %d != %d, decompress again."
                        % (file_num, len(ALL_LANGUAGES)))
                    shutil.rmtree(fullname)
                    _decompress(
                        os.path.join(default_root, os.path.basename(url)))
        else:
            if not os.path.exists(fullname) or (
                    data_hash and not md5file(fullname) == data_hash):
                get_path_from_url(url, default_root, zipfile_hash)

        return fullname
Example #4
0
    def forward(self, inputs, lengths, labels, old_version_labels=None):
        """
        Calculate the crf loss. Let $$ Z(x) = \\sum_{y'}exp(score(x,y')) $$, means the sum of all path scores,
        then we have $$ loss = -logp(y|x) = -log(exp(score(x,y))/Z(x)) = -score(x,y) + logZ(x) $$

        Args:
            inputs (Tensor):
                The input predicted tensor. Its dtype is float32 and has a shape of `[batch_size, sequence_length, num_tags]`.
            lengths (Tensor):
                The input length. Its dtype is int64 and has a shape of `[batch_size]`.
            labels (Tensor) :
                The input label tensor. Its dtype is int64 and has a shape of `[batch_size, sequence_length]`
            old_version_labels (Tensor, optional): Unused parameter kept for compatibility with older versions. Defaults to ``None``.

        Returns:
            Tensor: The crf loss. Its dtype is float32 and has a shape of `[batch_size]`.
        """
        # Note: When close to convergence, the loss could be a small negative number, which may be caused by underflow when calculating exp in logsumexp.
        #       We add relu here to avoid a negative loss. In theory, the crf loss must be greater than or equal to 0, so relu will not affect it.
        if old_version_labels is not None:
            # TODO(qiujinxuan): rm compatibility support after lic.
            labels = old_version_labels
            if not getattr(self, "has_warn", False):
                logger.warning(
                    'Compatibility Warning: The params of LinearChainCrfLoss.forward have been modified. The third param is `labels`, and the fourth is unnecessary. Please update the usage.'
                )
                self.has_warn = True
        loss = nn.functional.relu(
            self.crf.forward(inputs, lengths) -
            self.crf.gold_score(inputs, labels, lengths))
        return loss
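A minimal usage sketch of this loss (the import path `paddlenlp.layers` and the `LinearChainCrf(num_tags)` constructor arguments are assumptions; check the actual signatures):

import paddle
from paddlenlp.layers import LinearChainCrf, LinearChainCrfLoss  # import path assumed

batch_size, seq_len, num_tags = 2, 5, 4
# Emission scores from any encoder head, shape [batch_size, seq_len, num_tags].
emissions = paddle.rand([batch_size, seq_len, num_tags], dtype="float32")
lengths = paddle.to_tensor([5, 3], dtype="int64")
labels = paddle.randint(0, num_tags, [batch_size, seq_len], dtype="int64")

crf = LinearChainCrf(num_tags)  # constructor arguments assumed
loss_fn = LinearChainCrfLoss(crf)
loss = loss_fn(emissions, lengths, labels)  # shape [batch_size], non-negative after relu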
Example #5
0
    def convert_tokens_to_ids(self, tokens):
        """
        Converts a single token or a sequence of tokens to an index or a
        sequence of indices using the vocab.

        Args:
            tokens (str|List[str]|tuple(str)):
                A single token or a sequence of tokens.

        Returns:
            int|List[int]: The converted token id or token ids.

        Example:
            .. code-block::

                from paddlenlp.transformers import CTRLTokenizer

                tokenizer = CTRLTokenizer.from_pretrained('ctrl')
                print(tokenizer.convert_tokens_to_ids(['Welcome', 'to', 'use', 'Padd@@', 'le@@', 'Padd@@', 'le', 'and', 'Padd@@', 'le@@', 'N@@', 'LP']))
                # [41116, 3, 191, 40324, 1162, 40324, 992, 2, 40324, 1162, 633, 11135]

        """
        ids = []
        if isinstance(tokens, str):
            return self._convert_token_to_id(tokens)
        for token in tokens:
            ids.append(self._convert_token_to_id(token))
        if len(ids) > self.max_len:
            logger.warning(
                "Token indices sequence length is longer than the specified maximum "
                "sequence length for this CTRL model ({} > {}). Running this "
                "sequence through the model will result in indexing errors.".format(
                    len(ids), self.max_len))
        return ids
Example #6
0
def convert_to_fp16(transformer_encoder):
    """ Convert paddle.nn.TransformerEncoder's parameter from float32 to float16

    Args:
        transformer_encoder (object, paddle.nn.TransformerEncoder):
            The object to be converted to float16 in place; it must be an instance
            of paddle.nn.TransformerEncoder.
    """
    if not isinstance(transformer_encoder, paddle.nn.TransformerEncoder):
        logger.warning(
            "transformer_encoder is not an instance of paddle.nn.TransformerEncoder, "
            "so it is returned with no parameter conversion.")
        return transformer_encoder
    else:
        encoder_layers = transformer_encoder.layers

        for mod in encoder_layers:
            mod.norm1.weight = transfer_param(mod.norm1.weight,
                                              restore_data=True)
            mod.norm1.bias = transfer_param(mod.norm1.bias,
                                            is_bias=True,
                                            restore_data=True)
            mod.norm2.weight = transfer_param(mod.norm2.weight,
                                              restore_data=True)
            mod.norm2.bias = transfer_param(mod.norm2.bias,
                                            is_bias=True,
                                            restore_data=True)

            mod.linear1.weight = transfer_param(mod.linear1.weight,
                                                restore_data=True)
            mod.linear1.bias = transfer_param(mod.linear1.bias,
                                              is_bias=True,
                                              restore_data=True)

            mod.self_attn.q_proj.weight = transfer_param(
                mod.self_attn.q_proj.weight, restore_data=True)
            mod.self_attn.q_proj.bias = transfer_param(
                mod.self_attn.q_proj.bias, is_bias=True, restore_data=True)
            mod.self_attn.k_proj.weight = transfer_param(
                mod.self_attn.k_proj.weight, restore_data=True)
            mod.self_attn.k_proj.bias = transfer_param(
                mod.self_attn.k_proj.bias, is_bias=True, restore_data=True)
            mod.self_attn.v_proj.weight = transfer_param(
                mod.self_attn.v_proj.weight, restore_data=True)
            mod.self_attn.v_proj.bias = transfer_param(
                mod.self_attn.v_proj.bias, is_bias=True, restore_data=True)
            mod.self_attn.out_proj.weight = transfer_param(
                mod.self_attn.out_proj.weight, restore_data=True)
            mod.self_attn.out_proj.bias = transfer_param(
                mod.self_attn.out_proj.bias, is_bias=True, restore_data=True)

            mod.linear2.weight = transfer_param(mod.linear2.weight,
                                                restore_data=True)
            mod.linear2.bias = transfer_param(mod.linear2.bias,
                                              is_bias=True,
                                              restore_data=True)
        logger.info(
            "Converted transformer_encoder's parameters from float32 to float16 successfully."
        )
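A small sketch of converting a plain encoder in place; the `convert_to_fp16` import path below is an assumption (in PaddleNLP it is defined alongside `enable_faster_encoder`):

import paddle
from paddlenlp.ops import convert_to_fp16  # import path assumed

layer = paddle.nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072)
encoder = paddle.nn.TransformerEncoder(layer, num_layers=2)
convert_to_fp16(encoder)  # norm/linear/attention parameters become float16 in place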
Example #7
0
def to_distill(self,
               return_qkv=False,
               return_attentions=False,
               return_layer_outputs=False,
               layer_index=-1):
    """
    Can be bound to object with transformer encoder layers, and make model
    expose attributes `outputs.q`, `outputs.k`, `outputs.v`,
    `outputs.scaled_qks`, `outputs.hidden_states`and `outputs.attentions` of
    the object for distillation.
    It could be returned intermediate tensor using in MiniLM and TinyBERT
    strategy.
    """
    logger.warning(
        "`to_distill` is an experimental API and subject to change.")
    MultiHeadAttention._forward = attention_forward
    TransformerEncoderLayer._forward = transformer_encoder_layer_forward
    TransformerEncoder._forward = transformer_encoder_forward
    BertForSequenceClassification._forward = bert_forward

    if return_qkv:
        # forward function of student class should be replaced for distributed training.
        TinyBertForPretraining._forward = minilm_pretraining_forward
        PPMiniLMForSequenceClassification._forward = minilm_pretraining_forward
    else:
        TinyBertForPretraining._forward = tinybert_forward

    def init_func(layer):
        if isinstance(
                layer,
            (MultiHeadAttention, TransformerEncoderLayer, TransformerEncoder,
             TinyBertForPretraining, BertForSequenceClassification,
             PPMiniLMForSequenceClassification)):
            layer.forward = layer._forward
            if isinstance(layer, TransformerEncoder):
                layer.return_layer_outputs = return_layer_outputs
                layer.layer_index = layer_index
            if isinstance(layer, MultiHeadAttention):
                layer.return_attentions = return_attentions
                layer.return_qkv = return_qkv

    for layer in self.children():
        layer.apply(init_func)

    base_model_prefix = self._layers.base_model_prefix if isinstance(
        self, paddle.DataParallel) else self.base_model_prefix

    # For distribute training
    if isinstance(self, paddle.DataParallel):
        if hasattr(self._layers, base_model_prefix):
            self.outputs = getattr(self._layers, base_model_prefix).encoder
        else:
            self.outputs = self._layers.encoder
    else:
        if hasattr(self, base_model_prefix):
            self.outputs = getattr(self, base_model_prefix).encoder
        else:
            self.outputs = self.encoder
    return self
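A rough usage sketch for a distillation setup; the import path and the pretrained weight name are assumptions, not taken from the snippet above:

from paddlenlp.transformers import TinyBertForPretraining
from paddlenlp.transformers.distill_utils import to_distill  # import path assumed

student = TinyBertForPretraining.from_pretrained("tinybert-6l-768d-v2")  # weight name assumed
student = to_distill(student, return_qkv=True, layer_index=-1)
# After a forward pass, the exposed encoder is reachable via `student.outputs`,
# e.g. `student.outputs.q` / `.k` / `.v` for MiniLM-style relation losses.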
Example #8
0
def enable_faster_encoder(self, use_fp16=False, encoder_lib=None):
    """
    Compiles the fused encoder operator integrated with FasterTransformer using
    JIT (Just-In-Time) compilation, and replaces the `forward` function of the
    `paddle.nn.TransformerEncoder` and `paddle.nn.TransformerEncoderLayer`
    sublayers of `self` to support inference using FasterTransformer.

    Examples:

        .. code-block:: python

            from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

            model.eval()
            model = enable_faster_encoder(model)
            enc_out = model(src, src_mask)
            model = disable_faster_encoder(model)
    """

    def init_func(layer):
        if isinstance(layer, TransformerEncoderLayer):
            is_usable = True
            if layer._config['bias_attr'] == False:
                logger.warning("`False` for paddle.nn.TransformerEncoder's" \
                               " parameter `bias_attr` is not supported in " \
                               "FasterTransformer by now. The original forward" \
                               " will be involved.")
                is_usable = False
            if layer._config['activation'] not in ('relu', 'gelu'):
                logger.warning("Only 'relu' or 'gelu' is supported by now. " \
                                "The original forward will be involved.")
                is_usable = False
            if is_usable:
                layer.forward = layer._ft_forward
        elif isinstance(layer, TransformerEncoder):
            layer.forward = layer._ft_forward
            if use_fp16:
                convert_to_fp16(layer)

    if not self.training:
        try:
            # Pass the decoding lib to prevent re-building the encoder.
            # TODO: check whether the decoding lib already contains the encoder or not.
            if encoder_lib is not None:
                if "FasterTransformer" not in LOADED_EXT.keys():
                    ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
                        encoder_lib)
                    LOADED_EXT["FasterTransformer"] = ops
            else:
                load("FasterTransformer", verbose=True)
        except Exception:
            logger.warning(
                "Exception occurs when using FasterEncoder. " \
                "The original forward will be involved. ")
            return self
        for layer in self.children():
            layer.apply(init_func)
    return self
Example #9
0
 def add_callback(self, callback):
     cb = callback() if isinstance(callback, type) else callback
     cb_class = callback if isinstance(callback,
                                       type) else callback.__class__
     if cb_class in [c.__class__ for c in self.callbacks]:
         logger.warning(
             f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current "
             f"list of callbacks is:\n{self.callback_list}")
     self.callbacks.append(cb)
Example #10
0
def load(name, build_dir=None, force=False, verbose=False, **kwargs):
    # TODO(guosheng): Need a better way to handle unsupported environments such as
    # CPU-only. Currently, raise NotImplementedError and skip `_jit_compile`.
    # Otherwise, `_jit_compile` would output the error to stdout (when verbose is
    # True) and raise `RuntimeError`, which is unfriendly to users though it has
    # no other bad effect.
    if CUDA_HOME is None:
        logger.warning("%s is not available because CUDA can not be found." %
                       name)
        raise NotImplementedError
    if name in LOADED_EXT.keys():
        return LOADED_EXT[name]
    if build_dir is None:
        # Maybe under the package dir is better to avoid cmake source path
        # conflicts between different source paths.
        # build_dir = os.path.join(PPNLP_HOME, 'extensions')
        build_dir = os.path.join(str(Path(__file__).parent.resolve()),
                                 'extensions')
    build_base_dir = os.path.abspath(
        os.path.expanduser(os.path.join(build_dir, name)))
    if not os.path.exists(build_base_dir):
        os.makedirs(build_base_dir)

    extension = get_extension_maker(name)(name, **kwargs)
    # Check if 'target' is out-of-date with respect to any file to avoid rebuild
    if isinstance(extension, CMakeExtension):
        # `CppExtension`/`CUDAExtension` has version management via `PaddleBuildExtension`.
        # Maybe move this to CMakeExtension later.
        # TODO(guosheng): flags/args changes may also trigger build, and maybe
        # need version manager like `PaddleBuildExtension`.
        ext_filename = extension.get_target_filename()
        ext_filepath = os.path.join(build_base_dir, ext_filename)
        if not force:
            ext_sources = extension.sources
            if os.path.exists(ext_filepath) and not newer_group(
                    ext_sources, ext_filepath, 'newer'):
                logger.debug("skipping '%s' extension build (up-to-date)" %
                             name)
                ops = load_op_meta_info_and_register_op(ext_filepath)
                LOADED_EXT[name] = ops
                return LOADED_EXT[name]

    # write setup file and jit compile
    file_path = os.path.join(build_dir, "{}_setup.py".format(name))
    _write_setup_file(name, file_path, build_base_dir, **kwargs)
    _jit_compile(file_path, verbose)
    if isinstance(extension, CMakeExtension):
        # Load a shared library (if exists) only to register op.
        if os.path.exists(ext_filepath):
            ops = load_op_meta_info_and_register_op(ext_filepath)
            LOADED_EXT[name] = ops
            return LOADED_EXT[name]
    else:
        # Import as callable python api
        return _import_module_from_library(name, build_base_dir, verbose)
Example #11
0
 def forward(self, inputs, lengths, labels, old_version_labels=None):
     # Note: When close to convergence, the loss could be a small negative number, which may be caused by underflow when calculating exp in logsumexp.
     #       We add relu here to avoid a negative loss. In theory, the crf loss must be greater than or equal to 0, so relu will not affect it.
     if old_version_labels is not None:
         # TODO(qiujinxuan): rm compatibility support after lic.
         labels = old_version_labels
         if not getattr(self, "has_warn", False):
             logger.warning(
                 'Compatibility Warning: The params of LinearChainCrfLoss.forward have been modified. The third param is `labels`, and the fourth is unnecessary. Please update the usage.'
             )
             self.has_warn = True
     return nn.functional.relu(
         self.crf.forward(inputs, lengths) -
         self.crf.gold_score(inputs, labels, lengths))
Example #12
0
    def compute(self, lengths, predictions, labels, dummy=None):
        """
        Computes the precision, recall and F1-score for chunk detection.

        Args:
            lengths (Tensor): The valid length of every sequence, a tensor with shape `[batch_size]`
            predictions (Tensor): The predictions index, a tensor with shape `[batch_size, sequence_length]`.
            labels (Tensor): The labels index, a tensor with shape `[batch_size, sequence_length]`.
            dummy (Tensor, optional): Unused parameter kept for compatibility with the older signature `inputs`, `lengths`, `predictions`, `labels`. Defaults to None.

        Returns:
            tuple: Returns tuple (`num_infer_chunks, num_label_chunks, num_correct_chunks`).

            With the fields:

            - `num_infer_chunks` (Tensor):
                The number of the inference chunks.

            - `num_label_chunks` (Tensor):
                The number of the label chunks.

            - `num_correct_chunks` (Tensor):
                The number of the correct chunks.
        """
        if dummy is not None:
            # TODO(qiujinxuan): rm compatibility support after lic.
            dummy, lengths, predictions, labels = lengths, predictions, labels, dummy
            if not getattr(self, "has_warn", False):
                logger.warning(
                    'Compatibility Warning: The params of ChunkEvaluator.compute have been modified. The old version is `inputs`, `lengths`, `predictions`, `labels` while the current version is `lengths`, `predictions`, `labels`. Please update the usage.'
                )
                self.has_warn = True
        labels = labels.numpy()
        predictions = predictions.numpy()
        unpad_labels = [[
            self.id2label_dict[index]
            for index in labels[sent_index][:lengths[sent_index]]
        ] for sent_index in range(len(lengths))]
        unpad_predictions = [[
            self.id2label_dict.get(index, "O")
            for index in predictions[sent_index][:lengths[sent_index]]
        ] for sent_index in range(len(lengths))]

        pred_sum, tp_sum, true_sum = extract_tp_actual_correct(
            unpad_labels, unpad_predictions, self.suffix)
        num_correct_chunks = paddle.to_tensor([tp_sum.sum()])
        num_infer_chunks = paddle.to_tensor([pred_sum.sum()])
        num_label_chunks = paddle.to_tensor([true_sum.sum()])

        return num_infer_chunks, num_label_chunks, num_correct_chunks
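A small usage sketch, assuming `ChunkEvaluator` is importable from `paddlenlp.metrics`; the BIO label list is illustrative only:

import paddle
from paddlenlp.metrics import ChunkEvaluator

evaluator = ChunkEvaluator(label_list=["B-PER", "I-PER", "O"])  # label list is made up
lengths = paddle.to_tensor([3], dtype="int64")
predictions = paddle.to_tensor([[0, 1, 2]], dtype="int64")
labels = paddle.to_tensor([[0, 1, 2]], dtype="int64")

num_infer, num_label, num_correct = evaluator.compute(lengths, predictions, labels)
evaluator.update(num_infer.numpy(), num_label.numpy(), num_correct.numpy())
precision, recall, f1 = evaluator.accumulate()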
Example #13
0
    def on_evaluate(self, args, state, control, metrics, **kwargs):
        metric_to_check = args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        metric_value = metrics.get(metric_to_check)

        if metric_value is None:
            logger.warning(
                f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled"
            )
            return

        self.check_metric_value(args, state, control, metric_value)
        if self.early_stopping_patience_counter >= self.early_stopping_patience:
            control.should_training_stop = True
Example #14
0
    def __init__(self, vocab, do_lower_case=False, is_split_into_words=False):
        super(FasterTokenizer, self).__init__()

        try:
            self.mod = importlib.import_module("paddle._C_ops")
        except Exception:
            logger.warning(
                f"The paddlepaddle version is {paddle.__version__}, not the latest. Please upgrade the paddlepaddle package (>= 2.2.1)."
            )
            self.mod = importlib.import_module("paddle.fluid.core.ops")

        vocab_buffer = to_vocab_buffer(vocab, "vocab")
        self.register_buffer("vocab", vocab_buffer, persistable=True)

        self.do_lower_case = do_lower_case
        self.is_split_into_words = is_split_into_words
Example #15
0
    def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
        self.callbacks = []
        for cb in callbacks:
            self.add_callback(cb)
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.train_dataloader = None
        self.eval_dataloader = None

        if not any(
                isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
            logger.warning(
                "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
                +
                "should add one before training with `trainer.add_callback(DefaultFlowCallback)`. The current list of "
                + "callbacks is:\n" + self.callback_list)
Example #16
0
 def init_func(layer):
     if isinstance(layer, TransformerEncoderLayer):
         is_usable = True
         if layer._config['bias_attr'] == False:
             logger.warning("`False` for paddle.nn.TransformerEncoder's" \
                            " parameter `bias_attr` is not supported in " \
                            "FasterTransformer by now. The original forward" \
                            " will be involved.")
             is_usable = False
         if layer._config['activation'] not in ('relu', 'gelu'):
             logger.warning("Only 'relu' or 'gelu' is supported by now. " \
                             "The original forward will be involved.")
             is_usable = False
         if is_usable:
             layer.forward = layer._ft_forward
     elif isinstance(layer, TransformerEncoder):
         layer.forward = layer._ft_forward
         if use_fp16:
             convert_to_fp16(layer)
Example #17
0
def get_data_file(args):
    files = [
        os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
        if (os.path.isfile(os.path.join(args.input_dir, f)) and str(f).endswith(
            "_idx.npz"))
    ]
    files = [x.replace("_idx.npz", "") for x in files]
    if len(files) == 0:
        logger.warning(
            "No dataset found with names xxx_ids.npy and xxx_idx.npz! "
            "Trying to find the old compatible xxx_ids.npz files.")
    else:
        return files
    files = [
        os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
        if (os.path.isfile(os.path.join(args.input_dir, f)) and str(f).endswith(
            "_ids.npz"))
    ]
    files = [x.replace("_ids.npz", "") for x in files]
    return files
Example #18
0
    def save_pretrained(self, save_dir):
        """
        Saves model configuration and related resources (model state) as files
        under `save_dir`. The model configuration would be saved into a file named
        "model_config.json", and model state would be saved into a file
        named "model_state.pdparams".

        The `save_dir` can be used in `from_pretrained` as argument value
        of `pretrained_model_name_or_path` to re-load the trained model.

        Args:
            save_dir (str): Directory to save files into.

        Example:
            .. code-block::

                from paddlenlp.transformers import BertForSequenceClassification

                model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
                model.save_pretrained('./trained_model/')
                # reload from save_directory
                model = BertForSequenceClassification.from_pretrained('./trained_model/')
        """
        assert not os.path.isfile(
            save_dir
        ), "Saving directory ({}) should be a directory, not a file".format(
            save_dir)
        os.makedirs(save_dir, exist_ok=True)
        # Save model config
        self.save_model_config(save_dir)
        # Save model
        if paddle.in_dynamic_mode():
            file_name = os.path.join(
                save_dir,
                list(self.resource_files_names.values())[0])
            paddle.save(self.state_dict(), file_name)
        else:
            logger.warning(
                "Saving a pretrained model is only supported in dygraph mode for now!")
        # Save resources file
        self.save_resources(save_dir)
Example #19
0
    def convert_tokens_to_ids(self, tokens):
        """
        Converts a single token or a sequence of tokens to an index or a
        sequence of indices using the vocab.

        Args:
            tokens (str|List[str]|tuple(str)):
                A single token or a sequence of tokens.

        Returns:
            int|List[int]: The converted token id or token ids.

        Example:
            .. code-block::

                from paddlenlp.transformers import GPTTokenizer

                tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
                print(tokenizer.convert_tokens_to_ids(['Welcome', 'Ġto', 'Ġuse', 'ĠP', 'addle', 'P', 'addle', 'Ġand', 'ĠP', 'addle', 'N', 'LP']))
                # [14618, 284, 779, 350, 37382, 47, 37382, 290, 350, 37382, 45, 19930]
        """

        ids = []
        if isinstance(tokens, str):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens, 0)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.encoder.get(token, 0))
        if len(ids) > self.max_len:
            logger.warning(
                "Token indices sequence length is longer than the specified maximum "
                "sequence length for this OpenAI GPT model ({} > {}). Running this "
                "sequence through the model will result in indexing errors.".format(
                    len(ids), self.max_len))
        return ids
Example #20
0
    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        if position_ids is None:
            # Maybe we need to use the shape op to unify static and dynamic graphs.
            ones = paddle.ones_like(input_ids, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=-1)
            position_ids = seq_length - ones
            position_ids.stop_gradient = True

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_embeddings + position_embeddings
        if self.type_vocab_size != 0:
            if token_type_ids is None:
                token_type_ids = paddle.zeros_like(input_ids, dtype="int64")
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings += token_type_embeddings
        elif token_type_ids is not None:
            logger.warning(
                "There is no need to pass token_type_ids to this SKEP model, which is based on RoBERTa. "
                "The input token type ids will be ignored.")

        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
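A quick sketch of the position id trick above (plain paddle, no extra assumptions): cumsum over a tensor of ones yields 1-based positions, and subtracting the ones makes them 0-based.

import paddle

input_ids = paddle.to_tensor([[11, 42, 7, 0]], dtype="int64")
ones = paddle.ones_like(input_ids, dtype="int64")
position_ids = paddle.cumsum(ones, axis=-1) - ones
print(position_ids.numpy())  # [[0 1 2 3]]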
Example #21
0
def create_pretrained_dataset(
    args,
    input_path,
    local_rank,
    data_world_rank,
    data_world_size,
    eos_id,
    worker_init=None,
    max_seq_len=1024,
    places=None,
    data_holders=None,
    pipeline_mode=False,
):

    if local_rank == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from data_tools.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time),
              flush=True)

    device_world_size = paddle.distributed.get_world_size()
    device_world_rank = paddle.distributed.get_rank()

    if device_world_size > 1 and local_rank != 0:
        while True:
            try:
                import data_tools.helpers as helpers
                break
            except Exception as e:
                print("> wait for helpers to be compiled!")
                time.sleep(1)

    logger.info(
        "Distributed run: total device num: {}, distinct dataflow num: {}.".
        format(device_world_size, data_world_size))

    assert len(input_path) == 1, "GPT only supports one dataset for now."

    input_prefix = input_path[0]

    if os.path.isfile(input_prefix + "_ids.npz"):
        logger.warning(
            "You are using a compatibility-mode dataset; please build a new dataset as described in the README!"
        )
        process_datas = np.load(input_prefix + "_ids.npz",
                                mmap_mode="r+",
                                allow_pickle=True)
        sample_ids = process_datas["ids"]
        sample_lens = process_datas["lens"].astype("int32")
    else:
        for suffix in ["_ids.npy", "_idx.npz"]:
            if not os.path.isfile(input_prefix + suffix):
                raise ValueError("File Not found, %s" % (path + suffix))

        sample_ids = np.load(input_prefix + "_ids.npy",
                             mmap_mode="r",
                             allow_pickle=True)
        # All document ids, extended as a 1-D array.

        process_datas = np.load(input_prefix + "_idx.npz")
        # len(sample_lens) is the number of docs.
        # sum(sample_lens) should equal len(sample_ids).
        sample_lens = process_datas["lens"]

    splits = get_train_valid_test_split_(args.split, len(sample_lens))
    assert len(sample_lens) >= splits[
        -1], "The number of documents should be larger than the max of splits, but %s < %s" % (
            len(sample_lens), splits[-1])

    def build_dataset(index, name, num_samples):
        dataset = GPTDataset(file_prefix=input_prefix,
                             build_data_file=local_rank == 0,
                             micro_batch_size=args.micro_batch_size,
                             name="gpt_" + name,
                             max_seq_len=max_seq_len,
                             num_samples=num_samples,
                             documents=np.arange(splits[index],
                                                 splits[index + 1]),
                             sample_ids=sample_ids,
                             sample_lens=sample_lens,
                             eos_id=eos_id,
                             seed=args.seed)
        batch_sampler = DistributedBatchSampler(
            dataset,
            batch_size=args.micro_batch_size,
            num_replicas=data_world_size,
            rank=data_world_rank,
            shuffle=False,
            drop_last=True)

        if pipeline_mode:

            def data_gen():
                for data in dataset:
                    yield tuple(
                        [np.expand_dims(np.array(x), axis=0) for x in data])

            data_loader = paddle.fluid.io.DataLoader.from_generator(
                feed_list=data_holders, capacity=70, iterable=False)
            data_loader.set_batch_generator(data_gen, places)
        else:
            data_loader = DataLoader(dataset=dataset,
                                     places=places,
                                     feed_list=data_holders,
                                     batch_sampler=batch_sampler,
                                     num_workers=0,
                                     worker_init_fn=worker_init,
                                     collate_fn=Tuple(Stack(), Stack(),
                                                      Stack(), Stack(),
                                                      Stack()),
                                     return_list=False)
        return data_loader

    # Note, data should be broadcast to all devices.
    # For train, valid, test, the distinct data num is data_world_size.
    train_data_loader = build_dataset(
        0, "train", args.micro_batch_size * args.max_steps * data_world_size)
    if pipeline_mode:
        valid_data_loader, test_data_loader = None, None
    else:
        valid_data_loader = build_dataset(
            1, "valid",
            args.micro_batch_size * (args.max_steps // args.eval_freq + 1) *
            args.eval_iters * data_world_size)
        test_data_loader = build_dataset(
            2, "test",
            args.micro_batch_size * args.test_iters * data_world_size)

    return train_data_loader, valid_data_loader, test_data_loader
Example #22
0
from attrdict import AttrDict
import os
import _locale
import jieba
import paddle
from paddlenlp.data import Vocab
from paddlenlp.transformers import position_encoding_init
from paddlenlp.utils.log import logger
from subword_nmt import subword_nmt
import websocket

open_speech = True
try:
    from pyaudio import PyAudio, paInt16
except ImportError as e:
    open_speech = False
    logger.warning("No module named 'pyaudio', so no audio demo.")

import const
from model_demo import SimultaneousTransformerDemo

# By default, the Windows system opens the file with GBK code,
# and the subword_nmt package does not support setting open encoding,
# so it is set to UTF-8 uniformly.
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

is_win = False
if os.name == 'nt':
    is_win = True


class STACLTokenizer:
Example #23
0
def do_train(args):
    # Initialize the paddle and paddle fleet execution environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Create the random seed for the worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    assert args.dp_degree * args.sharding_degree * args.mp_degree * args.pp_degree == worker_num, \
        "The product of the parallelism degrees should be equal to worker_num."

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create the log writer; training results are shown on the last card of the pipeline.
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        # if os.path.exists(log_writer_path):
        #     shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # load config in checkpoint
    global_step = 0
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        data_holders = create_data_holder(args)
        # 0. input_ids,
        # 1. segment_ids,
        # 2. input_mask,
        # 3. masked_lm_positions,
        # 4. masked_lm_labels,
        # 5. next_sentence_labels

        [
            input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_labels
        ] = data_holders

        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

        train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
            args,
            data_file,
            tokenizer,
            data_world_size=topo.data_info.size,
            data_world_rank=topo.data_info.rank,
            max_seq_len=args.max_seq_len,
            places=paddle.static.cuda_places(),
            data_holders=data_holders,
            current_step=global_step)
        fleet.init(is_collective=True)

        if args.model_name_or_path in pretrained_models_list:
            model_config = model_class.pretrained_init_configuration[
                args.model_name_or_path]
            if model_config["vocab_size"] % 8 != 0:
                model_config["vocab_size"] += 8 - (model_config["vocab_size"] %
                                                   8)
            model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
            model_config[
                "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
            model = model_class(base_class(**model_config))
        else:
            model, _ = model_class.from_pretrained(
                args.model_name_or_path,
                hidden_dropout_prob=args.hidden_dropout_prob,
                attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            )

        # Create the model for the gpt pretrain
        prediction_scores, seq_relationship_score = model(
            input_ids=input_ids,
            token_type_ids=segment_ids,
            position_ids=None,
            attention_mask=input_mask,
            masked_positions=masked_lm_positions)

        criterion = criterion_class(with_nsp_loss=args.binary_head)
        if args.binary_head:
            lm_loss, sop_loss = criterion(prediction_scores,
                                          seq_relationship_score,
                                          masked_lm_labels,
                                          next_sentence_labels)
            loss = lm_loss + sop_loss
        else:
            loss = criterion(prediction_scores, seq_relationship_score,
                             masked_lm_labels)

        # Create the learning rate scheduler and optimizer
        if args.decay_steps is None:
            args.decay_steps = args.max_steps

        # lr_scheduler = CosineAnnealingWithWarmupDecay(
        #     max_lr=args.max_lr,
        #     min_lr=args.min_lr,
        #     warmup_step=args.warmup_rate * args.max_steps,
        #     decay_step=args.decay_steps, last_epoch=global_step)

        lr_scheduler = LinearDecayWithWarmup(args.max_lr,
                                             args.max_steps,
                                             args.warmup_rate,
                                             last_epoch=global_step)

        clip = None
        if args.grad_clip > 0:
            clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                clip_norm=args.grad_clip)

        decay_param = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        logger.info("Using paddle.optimizer.AdamW.")
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            grad_clip=clip,
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_param)
        # alias
        optimizer.apply_optimize = optimizer._apply_optimize

        # if args.use_recompute:
        #     dist_strategy.recompute = True
        #     dist_strategy.recompute_configs = {
        #         "checkpoints": model.bert.checkpoints
        #     }

        # Use the fleet api to compile the distributed optimizer
        optimizer = fleet.distributed_optimizer(optimizer,
                                                strategy=dist_strategy)

        optimizer.minimize(loss)
        logger.info(f'final strategy: {fleet._final_strategy()}')
        logger.info("The training meta optimizer is/are %s" %
                    fleet._get_applied_meta_list())

    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    if not os.path.isdir(program_desc_dir):
        os.mkdir(program_desc_dir)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)

    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygraph_path = os.path.join(args.model_name_or_path,
                                    "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygraph_path)
                init_static_with_params(
                    model, paddle.load(dygraph_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    # load checkpoint vars
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            paddle.static.load(main_program,
                               os.path.join(checkpoint_dir, "static_vars"),
                               exe)

    fetch_loss_vars = collections.OrderedDict()
    fetch_other_vars = collections.OrderedDict()
    fetch_loss_vars["loss"] = loss
    if args.binary_head:
        fetch_loss_vars["lm_loss"] = lm_loss
        fetch_loss_vars["sop_loss"] = sop_loss

    fetch_other_vars["learning_rate"] = main_program.global_block(
    ).vars["learning_rate_0"]

    additional_vars = collections.OrderedDict()
    if args.use_amp:
        for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
            additional_vars[key] = main_program.global_block().vars[key + "_0"]

    tic_train = time.time()
    while True:
        fetchs = []
        fetchs_keys = []
        if topo.is_last:
            fetchs = list(fetch_loss_vars.values()) + list(
                fetch_other_vars.values()) + list(additional_vars.values())
            fetchs_keys = list(fetch_loss_vars.keys()) + list(
                fetch_other_vars.keys()) + list(additional_vars.keys())

        # Bug fix: if valid_data_loader is not called here, the enumerate below would
        # call valid_data_loader many times and start a new random dataloader each time.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        for step, batch in enumerate(train_data_loader()):
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # Only count a global step after accumulate_steps micro-batch steps.
            if (step + 1) % args.accumulate_steps != 0:
                continue
            global_step += 1
            # In the 2.0 API, this function must be called to update the learning rate.
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    res = collections.defaultdict(float)
                    for k, v in zip(fetchs_keys, ret):
                        res[k] = v[0]

                    speed = args.logging_freq / (time.time() - tic_train)

                    loss_info = "loss: %.6f, lm_loss: %.6f, sop_loss: %.6f"

                    loss_info = ", ".join([
                        "{}: {:.6f}".format(k, res[k])
                        for k in fetch_loss_vars.keys()
                    ])

                    common_loginfo = "global step %d, %s, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, loss_info, speed,
                        speed * args.global_batch_size, res["learning_rate"])
                    additional_loginfo = ", ".join([
                        "{}: {}".format(k, res[k])
                        for k in additional_vars.keys()
                    ])
                    if additional_loginfo:
                        common_loginfo += ", " + additional_loginfo
                    logger.info(common_loginfo)
                    for k, v in res.items():
                        log_writer.add_scalar(k, v, global_step)

                tic_train = time.time()

            #if args.check_accuracy:
            #    if global_step >= args.max_steps:
            #        return
            #    else:
            #        continue

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "valid")
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)

                    step_config = {
                        "model_name": args.model_name_or_path,
                        "global_step": global_step,
                        "global_batch_size": args.global_batch_size,
                        "consumed_samples":
                        global_step * args.global_batch_size,
                    }

                    with open(os.path.join(output_dir, "config.yml"),
                              "w") as f:
                        yaml.dump(step_config,
                                  f,
                                  encoding='utf-8',
                                  allow_unicode=True)

                fleet.barrier_worker()

                logger.debug("saving models to {}".format(output_dir))
                if args.sharding_degree <= 1:
                    # Save on the first worker by default.
                    if worker_index == 0:
                        paddle.static.save(
                            main_program,
                            os.path.join(output_dir, "static_vars"))
                else:
                    # Use save_persistables in sharding, but it is slower.
                    save_persistables(exe,
                                      os.path.join(output_dir, "static_vars"),
                                      main_program)

            if global_step >= args.max_steps:
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
Example #24
0
def parse_args(MODEL_CLASSES):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " +
        ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(
            sum([
                list(classes[-1].pretrained_init_configuration.keys())
                for classes in MODEL_CLASSES.values()
            ], [])),
    )

    # Train I/O config
    parser.add_argument(
        "--input_dir",
        default=None,
        type=str,
        required=True,
        help="The input directory where the data will be read from.",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the training logs and checkpoints will be written."
    )
    parser.add_argument("--split",
                        type=str,
                        default='949,50,1',
                        help="Train/valid/test data split.")

    parser.add_argument("--max_seq_len",
                        type=int,
                        default=1024,
                        help="Max sequence length.")

    parser.add_argument(
        "--global_batch_size",
        default=None,
        type=int,
        help=
        "Global batch size for all training process. None for not check the size is valid. If we only use data parallelism, it should be device_num * micro_batch_size."
    )

    parser.add_argument(
        "--local_batch_size",
        default=None,
        type=int,
        help=
        "Global batch size for all training process. None for not check the size is valid. If we only use data parallelism, it should be device_num * micro_batch_size."
    )

    parser.add_argument(
        "--micro_batch_size",
        default=8,
        type=int,
        help="Batch size per device for one step training.",
    )

    # Default training config
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--grad_clip",
                        default=0.0,
                        type=float,
                        help="Grad clip for the parameter.")
    parser.add_argument("--max_lr",
                        default=0.00015,
                        type=float,
                        help="The initial max learning rate for Adam.")
    parser.add_argument("--min_lr",
                        default=1e-5,
                        type=float,
                        help="The initial min learning rate for Adam.")
    parser.add_argument(
        "--warmup_rate",
        default=0.01,
        type=float,
        help="Linear warmup over warmup_steps for learing rate.")

    # Adam optimizer config
    parser.add_argument(
        "--adam_beta1",
        default=0.9,
        type=float,
        help=
        "The beta1 for Adam optimizer. The exponential decay rate for the 1st moment estimates."
    )
    parser.add_argument(
        "--adam_beta2",
        default=0.999,
        type=float,
        help=
        "The bate2 for Adam optimizer. The exponential decay rate for the 2nd moment estimates."
    )
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")

    # Training steps config
    parser.add_argument(
        "--num_train_epochs",
        default=1,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--max_steps",
        default=500000,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--save_steps",
                        type=int,
                        default=500,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--decay_steps",
        default=360000,
        type=int,
        help=
        "The steps use to control the learing rate. If the step > decay_steps, will use the min_lr."
    )
    parser.add_argument("--logging_freq",
                        type=int,
                        default=1,
                        help="Log every X updates steps.")
    parser.add_argument("--eval_freq",
                        type=int,
                        default=500,
                        help="Evaluate for every X updates steps.")
    parser.add_argument("--eval_iters",
                        type=int,
                        default=10,
                        help="Evaluate the model use X steps data.")

    # Config for 4D Parallelism

    parser.add_argument(
        "--sharding_degree",
        type=int,
        default=1,
        help="Sharding degree. Share the parameters to many cards.")

    parser.add_argument("--dp_degree",
                        type=int,
                        default=1,
                        help="Data Parallelism degree.")
    parser.add_argument(
        "--mp_degree",
        type=int,
        default=1,
        help=
        "Model Parallelism degree. Spliting the linear layers to many cards.")
    parser.add_argument(
        "--pp_degree",
        type=int,
        default=1,
        help=
        "Pipeline Parallelism degree.  Spliting the the model layers to different parts."
    )
    parser.add_argument("--use_recompute",
                        type=str2bool,
                        nargs='?',
                        const=False,
                        help="Using the recompute to save the memory.")

    parser.add_argument(
        "--recompute_partition",
        type=str2bool,
        nargs='?',
        const=False,
        help=
        "use recompute_partition to support mp partition when use_recompute is True ."
    )

    parser.add_argument(
        "--recompute_offload",
        type=str2bool,
        nargs='?',
        const=False,
        help=
        "use recompute_offload to save the memory by offload when use_recompute is True ."
    )

    parser.add_argument(
        "--resume_dir",
        default="",
        type=str,
        required=True,
        help="The resume directory where the checkpoint will be resume.")

    # Pure FP16 config
    parser.add_argument("--use_pure_fp16",
                        type=str2bool,
                        nargs='?',
                        const=False,
                        help="Enable pure fp16 precision training.")

    parser.add_argument(
        "--scale_loss",
        type=float,
        default=32768,
        help=
        "The value of scale_loss for fp16. This is only used for AMP training."
    )

    parser.add_argument("--hidden_dropout_prob",
                        type=float,
                        default=0.1,
                        help="The hidden dropout prob.")

    parser.add_argument("--attention_probs_dropout_prob",
                        type=float,
                        default=0.1,
                        help="The attention probs dropout prob.")

    # MOE config
    parser.add_argument("--num_experts",
                        type=int,
                        default=1,
                        help="number of experts per worker")

    parser.add_argument("--top_k",
                        type=int,
                        default=2,
                        help="top_k for moe gate")

    parser.add_argument("--expert_mode",
                        type=str2bool,
                        nargs='?',
                        const=False,
                        help="Enable Moe mode.")

    parser.add_argument(
        "--balance_loss_weight",
        default=1.0,
        type=float,
        help=
        "The auxiliary loss generated by gate strategy to help balance experts."
    )

    parser.add_argument("--gate",
                        type=str,
                        default="gshard",
                        choices=["naive", "gshard", "switch"],
                        help="select naive, gshard, switch gate strategy.")

    # Other config
    parser.add_argument("--seed",
                        type=int,
                        default=1234,
                        help="Random seed for initialization")
    parser.add_argument("--check_accuracy",
                        type=str2bool,
                        nargs='?',
                        const=False,
                        help="Check accuracy for training process.")
    parser.add_argument("--device",
                        type=str,
                        default="gpu",
                        choices=["cpu", "gpu", "xpu"],
                        help="select cpu, gpu, xpu devices.")
    parser.add_argument("--lr_decay_style",
                        type=str,
                        default="cosine",
                        choices=["cosine", "none"],
                        help="Learning rate decay style.")

    args = parser.parse_args()
    args.test_iters = args.eval_iters * 10

    # process batch size
    process_batch_size(args)

    if args.check_accuracy:
        if args.hidden_dropout_prob != 0:
            args.hidden_dropout_prob = .0
            logger.warning(
                "The hidden_dropout_prob should set to 0 for accuracy checking."
            )
        if args.attention_probs_dropout_prob != 0:
            args.attention_probs_dropout_prob = .0
            logger.warning(
                "The attention_probs_dropout_prob should set to 0 for accuracy checking."
            )

    logger.info('{:20}:{}'.format("paddle commit id", paddle.version.commit))
    for arg in vars(args):
        logger.info('{:20}:{}'.format(arg, getattr(args, arg)))

    return args
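The boolean flags above rely on a str2bool converter that is referenced but not defined in this snippet. A minimal sketch of such a helper (an assumption about its behavior, not the repository's exact implementation) could look like this:

import argparse

def str2bool(v):
    # Hypothetical helper: map common truthy/falsy strings to bool for argparse flags.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

With nargs='?' and const=False, passing the bare flag sets the option to False, while an explicit value such as --use_recompute true converts through str2bool to True.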
Example #25
0
def do_train():
    parser = PdArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    paddle.set_device(training_args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        +
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # set_seed(args)
    data_args.dataset = data_args.dataset.strip()
    if data_args.dataset not in ALL_DATASETS:
        raise ValueError("Not found dataset {}".format(data_args.dataset))

    # Use yaml config to rewrite all args.
    config = ALL_DATASETS[data_args.dataset]
    for args in (model_args, data_args, training_args):
        for arg in vars(args):
            if arg in config.keys():
                setattr(args, arg, config[arg])

    training_args.per_device_train_batch_size = config["batch_size"]
    training_args.per_device_eval_batch_size = config["batch_size"]

    dataset_config = data_args.dataset.split(" ")
    raw_datasets = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
    )

    data_args.label_list = getattr(raw_datasets['train'], "label_list", None)
    num_classes = 1 if raw_datasets["train"].label_list is None else len(
        raw_datasets['train'].label_list)

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)
    loss_fct = nn.loss.CrossEntropyLoss(
    ) if data_args.label_list else nn.loss.MSELoss()

    # Define dataset pre-process function
    if "clue" in data_args.dataset:
        trans_fn = partial(clue_trans_fn, tokenizer=tokenizer, args=data_args)
    else:
        trans_fn = partial(seq_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collector
    batchify_fn = defaut_collator(tokenizer, data_args)

    # Dataset pre-process
    train_dataset = raw_datasets["train"].map(trans_fn)
    eval_dataset = raw_datasets["dev"].map(trans_fn)
    test_dataset = raw_datasets["test"].map(trans_fn)

    # Define the metrics of tasks.
    def compute_metrics(p):
        preds = p.predictions[0] if isinstance(p.predictions,
                                               tuple) else p.predictions

        preds = paddle.to_tensor(preds)
        label = paddle.to_tensor(p.label_ids)

        probs = F.softmax(preds, axis=1)
        metric = Accuracy()
        metric.reset()
        result = metric.compute(preds, label)
        metric.update(result)
        accu = metric.accumulate()
        metric.reset()
        return {"accuracy": accu}

    trainer = Trainer(
        model=model,
        criterion=loss_fct,
        args=training_args,
        data_collator=batchify_fn,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # Log model and data config
    trainer.print_config(model_args, "Model")
    trainer.print_config(data_args, "Data")

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    trainer.save_model()  # Saves the tokenizer too for easy upload
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # Evaluate and tests model
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)

    test_ret = trainer.predict(test_dataset)
    trainer.log_metrics("test", test_ret.metrics)
    if test_ret.label_ids is None:
        paddle.save(
            test_ret.predictions,
            os.path.join(training_args.output_dir, "test_results.pdtensor"),
        )

    # export inference model
    input_spec = [
        paddle.static.InputSpec(shape=[None, None],
                                dtype="int64"),  # input_ids
        paddle.static.InputSpec(shape=[None, None],
                                dtype="int64")  # segment_ids
    ]
    trainer.export_model(input_spec=input_spec,
                         load_best_model=True,
                         output_dir=model_args.export_model_dir)
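As a rough illustration of what the compute_metrics hook above does, the following self-contained sketch (using a hypothetical EvalOutput namedtuple in place of the trainer's prediction object, with toy logits and labels) computes accuracy the same way:

from collections import namedtuple
import numpy as np
import paddle
from paddle.metric import Accuracy

# Hypothetical stand-in for the prediction output passed to compute_metrics.
EvalOutput = namedtuple("EvalOutput", ["predictions", "label_ids"])

def compute_accuracy(p):
    preds = paddle.to_tensor(p.predictions)   # raw logits are fine: Accuracy takes the top-k of pred
    label = paddle.to_tensor(p.label_ids)
    metric = Accuracy()
    metric.reset()
    correct = metric.compute(preds, label)
    metric.update(correct)
    return {"accuracy": metric.accumulate()}

toy = EvalOutput(
    predictions=np.array([[2.0, 0.1], [0.2, 1.5]], dtype="float32"),
    label_ids=np.array([[0], [1]], dtype="int64"))
print(compute_accuracy(toy))  # expected: {'accuracy': 1.0}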
Example #26
0
def do_train():
    parser = PdArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    paddle.set_device(training_args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        +
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(
            training_args.output_dir
    ) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # set_seed(args)
    data_args.dataset = data_args.dataset.strip()
    if data_args.dataset not in ALL_DATASETS:
        raise ValueError("Not found dataset {}".format(data_args.dataset))

    # Use yaml config to rewrite all args.
    config = ALL_DATASETS[data_args.dataset]
    for args in (model_args, data_args, training_args):
        for arg in vars(args):
            if arg in config.keys():
                setattr(args, arg, config[arg])

    training_args.per_device_train_batch_size = config["batch_size"]
    training_args.per_device_eval_batch_size = config["batch_size"]

    dataset_config = data_args.dataset.split(" ")
    all_ds = load_dataset(
        dataset_config[0],
        None if len(dataset_config) <= 1 else dataset_config[1],
    )

    label_list = getattr(all_ds['train'], "label_list", None)
    data_args.label_list = label_list
    data_args.ignore_label = -100
    data_args.no_entity_id = len(data_args.label_list) - 1

    num_classes = 1 if all_ds["train"].label_list is None else len(
        all_ds['train'].label_list)

    # Define tokenizer, model, loss function.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path, num_classes=num_classes)

    class criterion(nn.Layer):
        def __init__(self):
            super(criterion, self).__init__()
            self.loss_fn = paddle.nn.loss.CrossEntropyLoss(
                ignore_index=data_args.ignore_label)

        def forward(self, *args, **kwargs):
            return paddle.mean(self.loss_fn(*args, **kwargs))

    loss_fct = criterion()

    # Define dataset pre-process function
    trans_fn = partial(ner_trans_fn, tokenizer=tokenizer, args=data_args)

    # Define data collector
    batchify_fn = ner_collator(tokenizer, data_args)

    # Dataset pre-process
    train_dataset = all_ds["train"].map(trans_fn)
    eval_dataset = all_ds["dev"].map(trans_fn)
    test_dataset = all_ds["test"].map(trans_fn)

    # Define the metrics of tasks.
    # Metrics
    metric = load_metric("seqeval")

    def compute_metrics(p):
        predictions, labels = p
        predictions = np.argmax(predictions, axis=2)

        # Remove ignored index (special tokens)
        true_predictions = [[
            label_list[p] for (p, l) in zip(prediction, label) if l != -100
        ] for prediction, label in zip(predictions, labels)]
        true_labels = [[
            label_list[l] for (p, l) in zip(prediction, label) if l != -100
        ] for prediction, label in zip(predictions, labels)]
        results = metric.compute(predictions=true_predictions,
                                 references=true_labels)
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    trainer = Trainer(
        model=model,
        criterion=loss_fct,
        args=training_args,
        data_collator=batchify_fn,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # Log model and data config
    trainer.print_config(model_args, "Model")
    trainer.print_config(data_args, "Data")

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    trainer.save_model()  # Saves the tokenizer too for easy upload
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # Evaluate and tests model
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)

    test_ret = trainer.predict(test_dataset)
    trainer.log_metrics("test", test_ret.metrics)
    if test_ret.label_ids is None:
        paddle.save(
            test_ret.predictions,
            os.path.join(training_args.output_dir, "test_results.pdtensor"),
        )

    # export inference model
    input_spec = [
        paddle.static.InputSpec(shape=[None, None],
                                dtype="int64"),  # input_ids
        paddle.static.InputSpec(shape=[None, None],
                                dtype="int64")  # segment_ids
    ]
    trainer.export_model(input_spec=input_spec,
                         load_best_model=True,
                         output_dir=model_args.export_model_dir)
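The label alignment inside compute_metrics above can be illustrated in isolation: positions whose label is the ignore index (-100, typically special tokens or sub-word continuations) are dropped before seqeval is applied. A toy example with a hypothetical label_list:

label_list = ["B-PER", "I-PER", "O"]
predictions = [[0, 1, 2, 2]]   # argmax class ids per token
labels = [[0, 1, -100, 2]]     # -100 marks positions to ignore

true_predictions = [[label_list[p] for (p, l) in zip(pred, lab) if l != -100]
                    for pred, lab in zip(predictions, labels)]
true_labels = [[label_list[l] for (p, l) in zip(pred, lab) if l != -100]
               for pred, lab in zip(predictions, labels)]
print(true_predictions)  # [['B-PER', 'I-PER', 'O']]
print(true_labels)       # [['B-PER', 'I-PER', 'O']]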
Example #27
0
def do_train(args):
    paddle.set_device(args.device)
    nranks = paddle.distributed.get_world_size()
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree,
        "sharding_degree": args.sharding_degree
    }

    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

    # set control in tensor parallel
    strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}

    fleet.init(is_collective=True, strategy=strategy)

    # Obtain rank information for hybrid parallelism
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    sharding_rank = hcg.get_sharding_parallel_rank()

    # Sharding stage 2/3 does not support hybrid parallelism yet
    if args.sharding_stage in [2, 3]:
        assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"

    sharding_size = hcg.get_sharding_parallel_world_size()
    data_world_rank = dp_rank * sharding_size + sharding_rank
    data_world_size = args.dp_degree * args.sharding_degree
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # seed control in hybrid parallel
    set_hyrbid_parallel_seed(args.seed, data_world_rank, mp_rank, pp_rank)

    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size, args.use_pure_fp16,
            False, global_rank).lower())

    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob

        model_config['num_partitions'] = args.mp_degree
        model_config['use_recompute'] = args.use_recompute
        if args.pp_degree == 1:
            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model = GPTForPretrainingPipe(**model_config)
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the GPT model
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]

    if args.sharding_stage == 1 and args.sharding_degree > 1:
        optimizer = DygraphShardingOptimizer(
            hcg=fleet.get_hybrid_communicate_group(),
            user_defined_strategy=strategy,
            params=model.parameters(),
            inner_optimizer_class=paddle.optimizer.AdamW,
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params)
    else:
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            parameters=model.parameters(),
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params,
            # TODO: remove 'multi_precision' in definition of optimizer
            # and add it to 'paddle.amp.decorate'
            multi_precision=args.use_pure_fp16)

    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        # level O2 means converting the network to FP16
        if args.sharding_stage not in [2, 3]:
            scaler = fleet.distributed_scaler(scaler)
        model = paddle.amp.decorate(
            models=model, level='O2', save_dtype='float32')

    # wrap sharding stage2/3 and add collective group
    # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in the future
    if args.sharding_stage in [2, 3]:
        scaler = scaler if args.use_pure_fp16 else None
        model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                                     args.sharding_offload)

    elif paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        files = get_train_data_file(args)
        files.sort()
        num_files = len(files)
        for f_id in range(num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args, [data_file],
                local_rank=local_rank,
                data_world_size=data_world_size,
                data_world_rank=data_world_rank,
                eos_id=tokenizer.eos_token_id)
            # Bug fix: if valid_data_loader is not called here, the later enumerate
            # would call it many times and start a new random dataloader each time.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            # time count
            train_reader_cost = 0.0
            train_run_cost = 0.0
            reader_start = time.time()
            for step, batch in enumerate(train_data_loader()):
                train_reader_cost += time.time() - reader_start
                train_start = time.time()

                global_step += 1
                tokens, loss_mask, position_ids, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True
                position_ids.stop_gradient = True

                if args.pp_degree == 1:
                    # In data-parallel mode, 'no_sync' can be used to improve
                    # performance when accumulating gradients.
                    loss = 0.0
                    for i in range(accumulate_steps):
                        start_index = i * args.micro_batch_size
                        end_index = start_index + args.micro_batch_size
                        with paddle.amp.auto_cast(
                                args.use_pure_fp16,
                                custom_black_list=[
                                    "reduce_sum",
                                    "c_softmax_with_cross_entropy",
                                    "elementwise_div"
                                ],
                                level='O2'):
                            preds = model(
                                tokens[start_index:end_index, :],
                                position_ids[start_index:end_index, :])
                            loss_mbs = criterion(
                                preds, labels[start_index:end_index, :],
                                loss_mask[start_index:end_index, :])
                        loss_mbs = loss_mbs / accumulate_steps
                        if args.use_pure_fp16:
                            scaler.scale(loss_mbs).backward()
                        else:
                            loss_mbs.backward()
                        loss = loss + loss_mbs

                    if args.use_pure_fp16:
                        if args.sharding_stage in [2, 3]:
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            scaler.minimize(optimizer, loss)
                    else:
                        optimizer.step()

                    if lr_scheduler is not None:
                        lr_scheduler.step()

                    optimizer.clear_grad()

                else:
                    data = [(tokens, position_ids), (labels, loss_mask)]
                    with paddle.amp.auto_cast(
                            args.use_pure_fp16,
                            custom_black_list=[
                                "reduce_sum", "c_softmax_with_cross_entropy",
                                "elementwise_div"
                            ],
                            level='O2'):
                        loss = model.train_batch(
                            data,
                            optimizer=optimizer,
                            lr_scheduler=lr_scheduler,
                            scaler=scaler if args.use_pure_fp16 else None)

                # Synchronize for profiling; removing this may be slightly faster
                paddle.device.cuda.synchronize()
                train_run_cost += time.time() - train_start
                # Profile for model benchmark
                profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (
                        train_reader_cost + train_run_cost)
                    avg_reader_cost = train_reader_cost / args.logging_freq

                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, speed: %.2f step/s, ips: %.0f tokens/s, ips_per_card: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss, avg_reader_cost,
                           1. / speed, speed, speed * default_global_tokens_num,
                           speed * default_global_tokens_num / nranks,
                           optimizer.get_lr()))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate",
                                          optimizer.get_lr(), global_step)

                    tic_train = time.time()
                    train_reader_cost = 0.0
                    train_run_cost = 0.0

                if args.check_accuracy:
                    if global_step >= args.max_steps:
                        return
                    else:
                        continue

                if global_step % args.eval_freq == 0:
                    # Since the valid data is broadcast to all devices, we evaluate on all devices.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                # TODO: 1. merge parameters while saving the model. 2. ensure that the model is saved and loaded correctly
                # Only dp_rank == 0 saves the model.
                if (global_step % args.save_steps == 0 or
                        global_step >= args.max_steps) and dp_rank == 0:

                    model_to_save = model._layers if paddle.distributed.get_world_size(
                    ) > 1 and args.sharding_stage not in [2, 3] else model
                    output_dir = os.path.join(args.output_dir,
                                              "step_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)

                    logger.info("Save model to %s" % output_dir)

                    if args.pp_degree > 1:
                        if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        model_to_save.save_state_dict(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt".
                                format(mp_rank, sharding_rank, pp_rank)))
                    else:
                        if args.sharding_stage == 3:
                            # If parameters need to be converted to CPU, pass convert2cpu=True
                            model_to_save.get_all_parameters(convert2cpu=False)
                        if mp_rank == 0 and sharding_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        model_to_save.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt".
                                format(mp_rank, sharding_rank)))

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

                reader_start = time.time()

            del train_data_loader
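The micro-batch loop above (each loss_mbs scaled by accumulate_steps, backward per micro batch, one optimizer step per global batch) can be reduced to a minimal dygraph sketch; the toy model, tensor sizes, and data below are illustrative assumptions only:

import paddle
import paddle.nn.functional as F

# Toy setup standing in for the GPT model and pretraining data
model = paddle.nn.Linear(4, 2)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3,
                                   parameters=model.parameters())
tokens = paddle.randn([8, 4])
labels = paddle.randint(0, 2, [8])
micro_batch_size, accumulate_steps = 2, 4

loss = 0.0
for i in range(accumulate_steps):
    start, end = i * micro_batch_size, (i + 1) * micro_batch_size
    logits = model(tokens[start:end])
    # Scale each micro-batch loss so the summed gradient matches the global batch
    loss_mbs = F.cross_entropy(logits, labels[start:end]) / accumulate_steps
    loss_mbs.backward()          # gradients accumulate across micro batches
    loss = loss + float(loss_mbs)
optimizer.step()                 # single update for the whole global batch
optimizer.clear_grad()
print("global-batch loss:", loss)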
Example #28
0
def do_train(args):
    paddle.set_device(args.device)

    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    if worker_num > 1:
        paddle.distributed.init_parallel_env()

    if args.dp_degree * args.sharding_degree == 1:
        args.dp_degree = worker_num
        args.sharding_degree = 1

    args_post_process(args, worker_num)

    logger.info('{:20}:{}'.format("paddle commit id", paddle.version.commit))
    for arg in vars(args):
        logger.info('{:20}:{}'.format(arg, getattr(args, arg)))

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": 1,
        "pp_degree": 1,
        "sharding_degree": 1
    }

    fleet.init(is_collective=True, strategy=strategy)
    hcg = fleet.get_hybrid_communicate_group()

    # Set the random seed for the worker
    set_seed(args)

    assert args.dp_degree * args.sharding_degree == worker_num, \
        "The product of degree num should be equal to worker_num."

    # Create the log writer
    log_writer = None
    if worker_index == 0:
        log_writer = LogWriter(os.path.join(args.output_dir, default_logdir()))

    # Define the input data in the static mode
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # load config in checkpoint
    global_step = 0
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
        model = model_class(base_class(**model_config))
    else:
        model = model_class.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    criterion = criterion_class()

    if worker_index == 0:
        # log the model config and args
        model_config_json = json.dumps(model.get_model_config(),
                                       ensure_ascii=False,
                                       indent=2)
        log_writer.add_text("model_config", model_config_json)
        args_dict = {"paddle commit id": str(paddle.version.commit)}
        for arg in vars(args):
            args_dict[arg] = str(getattr(args, arg))
        log_writer.add_text("args", json.dumps(args_dict, indent=2))

    # Create the learning rate scheduler and the optimizer
    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    assert args.warmup_rate <= 1.0 and args.warmup_rate >= 0.0, "warmup_rate should be in [0, 1]"
    args.warmup_steps = args.warmup_rate * args.max_steps

    lr_scheduler = LinearAnnealingWithWarmupDecay(
        args.max_lr,
        args.min_lr,
        warmup_step=args.warmup_steps,
        decay_step=args.decay_steps,
        last_epoch=global_step)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.fluid.clip.GradientClipByGlobalNorm(
            clip_norm=args.grad_clip)

    decay_param = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    logger.info("Using paddle.optimizer.AdamW.")
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler
        if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_param,
        multi_precision=args.use_amp)

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        model = paddle.amp.decorate(models=model,
                                    level='O2',
                                    save_dtype='float32')

    if paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    data_file = get_train_data_file(args)

    train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
        args,
        data_file,
        tokenizer,
        data_world_size=worker_num,
        data_world_rank=worker_index,
        max_seq_len=args.max_seq_len,
        current_step=global_step)

    # load checkpoint vars
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            logger.info("Try to load checkpoint from %s " % checkpoint_dir)
            opt_path = os.path.join(checkpoint_dir, "model_state.pdopt")
            params_path = os.path.join(checkpoint_dir, "model_state.pdparams")

            if os.path.exists(opt_path):
                opt_dict = paddle.load(opt_path)
                optimizer.set_state_dict(opt_dict)
                model_dict = paddle.load(params_path)
                model.set_state_dict(model_dict)
            else:
                logger.warning("No optimizer checkpoint file found in %s." %
                               opt_path)
            logger.info(
                "Checkpoint loaded from global step: {}".format(global_step))

    loss_global = {
        "loss": paddle.to_tensor(0.0),
        "lm_loss": paddle.to_tensor(0.0),
        "sop_loss": paddle.to_tensor(0.0),
    }
    tic_train = time.time()
    while True:
        # If valid_data_loader is not called here, the later enumerate would
        # call it many times and start a new random dataloader each time.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        # time count
        train_reader_cost = 0.0
        train_run_cost = 0.0
        reader_start = time.time()

        for step, batch in enumerate(train_data_loader()):
            train_reader_cost += time.time() - reader_start
            train_start = time.time()

            # 0. input_ids,
            # 1. segment_ids,
            # 2. input_mask,
            # 3. masked_lm_positions,
            # 4. masked_lm_labels,
            # 5. next_sentence_labels

            input_ids, segment_ids, input_mask, masked_lm_positions, \
            masked_lm_labels, next_sentence_labels = batch

            with paddle.amp.auto_cast(args.use_amp,
                                      custom_black_list=[
                                          "reduce_sum",
                                          "c_softmax_with_cross_entropy",
                                          "elementwise_div"
                                      ],
                                      level='O2'):

                # Forward pass of the ERNIE pretraining model
                prediction_scores, seq_relationship_score = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    position_ids=None,
                    attention_mask=input_mask,
                    masked_positions=masked_lm_positions)

                lm_loss, sop_loss = criterion(prediction_scores,
                                              seq_relationship_score,
                                              masked_lm_labels,
                                              next_sentence_labels)
                loss = lm_loss + sop_loss

            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()

            optimizer.clear_grad()
            train_run_cost += time.time() - train_start

            # Gradient accumulation: only advance the global step every accumulate_steps micro steps
            if (step + 1) % args.accumulate_steps != 0:
                continue

            global_step += 1

            loss_global["loss"] += loss.detach()
            loss_global["lm_loss"] += lm_loss.detach()
            loss_global["sop_loss"] += sop_loss.detach()

            if global_step % args.logging_freq == 0:
                log_info_dict = dict()
                log_info_dict["global_step"] = global_step
                for k, v in loss_global.items():
                    log_info_dict[k] = all_gather(v) / args.logging_freq
                    v.subtract_(v)
                if worker_index == 0:
                    speed = args.logging_freq / (time.time() - tic_train)
                    log_info_dict["learning_rate"] = lr_scheduler.get_lr()
                    log_info_dict["steps_per_second"] = speed
                    log_info_dict[
                        "samples_per_second"] = speed * args.global_batch_size

                    for k, v in log_info_dict.items():
                        log_writer.add_scalar("train/%s" % k, v, global_step)

                    common_loginfo = "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, log_info_dict["loss"],
                        log_info_dict["lm_loss"], log_info_dict["sop_loss"],
                        speed, log_info_dict["samples_per_second"],
                        log_info_dict["learning_rate"])

                    addition_info = ""
                    if args.use_amp:
                        amp_info = {
                            "loss_scaling": scaler._scale.item(),
                            "incr_count": scaler._incr_count,
                            "decr_count": scaler._decr_count
                        }
                        addition_info = ", ".join("%s: %d" % (k, v)
                                                  for k, v in amp_info.items())
                        addition_info = " " + addition_info
                        for k, v in amp_info.items():
                            log_writer.add_scalar("amp/%s" % k, v, global_step)

                    logger.info(common_loginfo + addition_info)

                tic_train = time.time()

            if lr_scheduler is not None:
                lr_scheduler.step()

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation

                run_evaluate(valid_data_loader,
                             model,
                             criterion,
                             args.eval_iters,
                             log_writer,
                             global_step,
                             args,
                             task_name="valid")
                tic_train = time.time()

            def save_ckpt(output_dir, model, tokenizer, args, global_step):
                step_config = {
                    "model_name": args.model_name_or_path,
                    "global_step": global_step,
                    "global_batch_size": args.global_batch_size,
                    "consumed_samples": global_step * args.global_batch_size,
                }

                logger.debug("saving models to {}".format(output_dir))
                model_to_save = model._layers if isinstance(
                    model, paddle.DataParallel) else model

                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                paddle.save(optimizer.state_dict(),
                            os.path.join(output_dir, "model_state.pdopt"))

                with open(os.path.join(output_dir, "config.yml"), "w") as f:
                    yaml.dump(step_config,
                              f,
                              encoding='utf-8',
                              allow_unicode=True)

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                if worker_index == 0:
                    save_ckpt(output_dir, model, tokenizer, args, global_step)

                if worker_num > 1:
                    paddle.distributed.barrier()
                tic_train = time.time()

            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)
                    save_ckpt(output_dir, model, tokenizer, args, global_step)

                if worker_num > 1:
                    paddle.distributed.barrier()

            if global_step >= args.max_steps:
                run_evaluate(test_data_loader,
                             model,
                             criterion,
                             args.test_iters,
                             log_writer,
                             global_step,
                             args,
                             task_name="test")
                del train_data_loader
                return
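The resume logic above hinges on the small config.yml written by save_ckpt; a minimal round trip of that metadata (with assumed values and a temporary directory in place of the real output_dir) looks roughly like this:

import os
import tempfile
import yaml

# Assumed metadata, mirroring the step_config dict written by save_ckpt
step_config = {
    "model_name": "ernie-1.0",
    "global_step": 1000,
    "global_batch_size": 64,
    "consumed_samples": 64000,
}

checkpoint_dir = os.path.join(tempfile.mkdtemp(), "model_last")
os.makedirs(checkpoint_dir, exist_ok=True)
with open(os.path.join(checkpoint_dir, "config.yml"), "w") as f:
    yaml.dump(step_config, f)

# On restart, the trainer reads the file back and resumes the step counters
with open(os.path.join(checkpoint_dir, "config.yml"), "r") as f:
    loaded = yaml.load(f, Loader=yaml.FullLoader)
assert loaded["global_batch_size"] == step_config["global_batch_size"]
print("resuming from global step", loaded["global_step"])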
Example #29
0
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '2.2.0'  # Maybe dev is better
import sys
if 'datasets' in sys.modules.keys():
    from paddlenlp.utils.log import logger
    logger.warning(
        "datasets module loaded before paddlenlp. "
        "This may cause PaddleNLP datasets to be unavalible in intranet.")
from . import data
from . import datasets
from . import embeddings
from . import ops
from . import layers
from . import metrics
from . import seq2vec
from . import transformers
from . import utils
from . import losses
from . import experimental
from .taskflow import Taskflow
import paddle
def do_train(args):
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Set the random seed for the worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create the log writer; training results are shown on the last card of the pipeline.
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        if os.path.exists(log_writer_path):
            import shutil
            shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        with paddle.utils.unique_name.guard():
            with paddle.static.device_guard('gpu:0'):
                data_holders = create_data_holder(args)
                [tokens, loss_mask, attention_mask, position_ids,
                 labels] = data_holders

                tokenizer = tokenizer_class.from_pretrained(
                    args.model_name_or_path)
                eos_id = tokenizer.eos_token_id

                train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                    args,
                    data_file,
                    data_world_size=topo.data_info.size,
                    data_world_rank=topo.data_info.rank,
                    eos_id=eos_id,
                    max_seq_len=args.max_seq_len,
                    places=paddle.static.cuda_places(),
                    data_holders=data_holders,
                    pipeline_mode=False,
                )

                if args.model_name_or_path in pretrained_models_list:
                    model_config = model_class.pretrained_init_configuration[
                        args.model_name_or_path]

                    model_config[
                        "hidden_dropout_prob"] = args.hidden_dropout_prob
                    model_config[
                        "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
                    model_config["topo"] = topo

                    model = guard(f'gpu:{args.pp_degree -1}')(
                        GPTForPretraining)(
                            guard(f'gpu:0')(GPTModel)(**model_config))
                else:
                    model, _ = GPTForPretraining.from_pretrained(
                        args.model_name_or_path,
                        hidden_dropout_prob=args.hidden_dropout_prob,
                        attention_probs_dropout_prob=args.
                        attention_probs_dropout_prob,
                        topo=topo)

                # Forward pass of the GPT pretraining model
                preds = model(tokens, position_ids, attention_mask)

                criterion = guard(f'gpu:{args.pp_degree -1}')(
                    GPTPretrainingCriterion)(topo)
                loss = criterion(preds, labels, loss_mask)

            # Create the learning rate scheduler and the optimizer
            if args.decay_steps is None:
                args.decay_steps = args.max_steps
            warmup_step = args.warmup_rate * args.decay_steps

            # TODO @ZHUI Use paddle network to support lr scheduler
            lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
                max_lr=args.max_lr,
                min_lr=args.min_lr,
                warmup_step=warmup_step,
                decay_step=args.decay_steps)

            clip = None
            if args.grad_clip > 0:
                clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=args.grad_clip)

            decay_param = [
                p.name for n, p in model.named_parameters()
                if not any(nd in n for nd in ["bias", "norm"])
            ]

            optimizer = paddle.optimizer.AdamW(
                learning_rate=lr_scheduler,
                beta1=args.adam_beta1,
                beta2=args.adam_beta2,
                epsilon=args.adam_epsilon,
                grad_clip=clip,
                weight_decay=args.weight_decay,
                apply_decay_param_fun=lambda x: x in decay_param)
            # alias
            optimizer.apply_optimize = optimizer._apply_optimize

            if args.use_recompute:
                dist_strategy.recompute = True
                dist_strategy.recompute_configs = {
                    "checkpoints": model.gpt.checkpoints
                }

            # Use the fleet api to compile the distributed optimizer
            optimizer = fleet.distributed_optimizer(optimizer,
                                                    strategy=dist_strategy)

            optimizer.minimize(loss)
            logger.info(f'final strategy: {fleet._final_strategy()}')
            logger.info("The training meta optimizer is/are %s" %
                        fleet._get_applied_meta_list())

    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    os.makedirs(program_desc_dir, exist_ok=True)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
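    # Clone the main program in test mode for evaluation.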
    test_program = main_program.clone(for_test=True)

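    # For a local checkpoint path (not a built-in model name), prefer static-graph
    # vars and fall back to dygraph parameters.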
    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygrah_path = os.path.join(args.model_name_or_path,
                                   "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygrah_path)
                init_static_with_params(
                    model, paddle.load(dygrah_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    global_step = 0
    tic_train = time.time()
    epoch = 0
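    # Fetch the learning-rate variable that the scheduler registers in the program.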
    learning_rate = main_program.global_block().vars["learning_rate_0"]
    while True:
        fetchs = []
        if topo.is_last:
            fetchs = [loss, learning_rate]

        # Note: materialize the data loader generators once here. Otherwise every
        # `enumerate(...)` below would call the loader again and restart it with a
        # new random order.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

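        # Run one training step of the compiled program per batch.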
        for step, batch in enumerate(train_data_loader()):
            global_step += 1
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # With the Paddle 2.0 API, the scheduler must be stepped explicitly to
            # update the learning rate.
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    loss_return, lr_return = ret
                    speed = args.logging_freq / (time.time() - tic_train)
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, speed: %.2f steps/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, loss_return[0], speed,
                           speed * args.global_batch_size * args.max_seq_len,
                           lr_return[0]))
                    log_writer.add_scalar("loss", loss_return[0], global_step)
                    log_writer.add_scalar("learning_rate", lr_return[0],
                                          global_step)
                tic_train = time.time()

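            # In accuracy-check mode, skip evaluation and checkpointing and only run
            # the fixed number of training steps.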
            if args.check_accuracy:
                if global_step >= args.max_steps:
                    return
                else:
                    continue

            if global_step % args.eval_freq == 0:
                # TODO: check the input data of validation
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "valid")
                tic_train = time.time()

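            # Periodically save static persistables plus the model config and
            # tokenizer files.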
            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
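                # On the first save, drop the `topo` entry from the init config so
                # it is not written into the saved model config.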
                if global_step == args.save_steps:
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

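            # Final evaluation on the test set, then exit once max_steps is reached.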
            if global_step >= args.max_steps:
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
        epoch += 1