Example #1
 def _init_deep_model(self, model_type, model_path, num_labels, num_regs=None):
     if 'roberta' in model_type:
         tokenizer = RobertaTokenizer.from_pretrained(model_path)
         config = RobertaConfig.from_pretrained(model_path)
         config.num_labels = num_labels
         model = RobertaForSequenceClassification.from_pretrained(model_path, config=config)
         model.eval()
         model.to(self.device)
     elif 'electra_multitask' in model_type:
         tokenizer = ElectraTokenizer.from_pretrained(model_path)
         tokenizer.add_special_tokens({'additional_special_tokens': ['[VALUES]']})
         config = ElectraConfig.from_pretrained(model_path)
         config.num_labels = num_labels
         config.num_regs = num_regs
         config.vocab_size = len(tokenizer)
         model = ElectraForSequenceClassificationMultiTask.from_pretrained(model_path, config=config)
         model.eval()
         model.to(self.device)
     elif 'electra' in model_type:
         tokenizer = ElectraTokenizer.from_pretrained(model_path)
         config = ElectraConfig.from_pretrained(model_path)
         config.num_labels = num_labels
         model = ElectraForSequenceClassification.from_pretrained(model_path, config=config)
         model.eval()
         model.to(self.device)
     else:
         raise NotImplementedError()
     return config, tokenizer, model
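For quick reference, the 'electra' branch above can be exercised on its own roughly as follows. This is only a sketch: the checkpoint name and label count are placeholders, not values taken from the example.

import torch
from transformers import (ElectraConfig, ElectraTokenizer,
                          ElectraForSequenceClassification)

model_path = "google/electra-small-discriminator"  # placeholder checkpoint
num_labels = 3                                     # placeholder label count

tokenizer = ElectraTokenizer.from_pretrained(model_path)
config = ElectraConfig.from_pretrained(model_path)
config.num_labels = num_labels
model = ElectraForSequenceClassification.from_pretrained(model_path, config=config)
model.eval()

with torch.no_grad():
    batch = tokenizer("a short test sentence", return_tensors="pt")
    logits = model(**batch)[0]  # shape (1, num_labels); the classification head is newly initialized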
Example #2
 def _load_model(self):
     config = ElectraConfig.from_pretrained(self.backbone)
     p_encoder = ElectraEncoder.from_pretrained(self.backbone,
                                                config=config).cuda()
     q_encoder = ElectraEncoder.from_pretrained(self.backbone,
                                                config=config).cuda()
     return p_encoder, q_encoder
Example #3
    def __init__(self):
        self.root_path = '..'
        self.checkpoint_path = f"{self.root_path}/checkpoint"
        self.save_ckpt_path = f"{self.checkpoint_path}/koelectra-wellnesee-text-classification.pth"
        model_name_or_path = "monologg/koelectra-base-discriminator"

        # Load the answers and categories
        self.category, self.answer = load_wellness_answer()

        ctx = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(ctx)

        # Load the saved checkpoint
        checkpoint = torch.load(self.save_ckpt_path, map_location=self.device)

        # Electra Tokenizer
        self.tokenizer = ElectraTokenizer.from_pretrained(model_name_or_path)

        electra_config = ElectraConfig.from_pretrained(model_name_or_path)
        self.model = koElectraForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path,
            config=electra_config,
            num_labels=359)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
Example #4
def get_model_and_tokenizer(model_name, device):
    save_ckpt_path = CHECK_POINT[model_name]

    if model_name == "koelectra":
        model_name_or_path = "monologg/koelectra-base-discriminator"

        tokenizer = ElectraTokenizer.from_pretrained(model_name_or_path)
        electra_config = ElectraConfig.from_pretrained(model_name_or_path)
        model = koElectraForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path,
            config=electra_config,
            num_labels=359)
    elif model_name == 'kobert':
        tokenizer = get_tokenizer()
        model = KoBERTforSequenceClassfication()

    if os.path.isfile(save_ckpt_path):
        checkpoint = torch.load(save_ckpt_path, map_location=device)
        pre_epoch = checkpoint['epoch']
        # pre_loss = checkpoint['loss']
        model.load_state_dict(checkpoint['model_state_dict'])

        print(f"load pretrain from: {save_ckpt_path}, epoch={pre_epoch}")

    return model, tokenizer
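For context, CHECK_POINT above is assumed to be a module-level dict mapping model names to checkpoint paths; a hypothetical call (with placeholder paths) would look like:

import torch

CHECK_POINT = {"koelectra": "checkpoint/koelectra-wellness-text-classification.pth"}  # placeholder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer = get_model_and_tokenizer("koelectra", device)
model.to(device)
model.eval()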
Example #5
    def __init__(self, root_path='../ai/chatbot'):
        checkpoint_path = f"{root_path}/checkpoint"
        self.model_path = f"{checkpoint_path}/koelectra-wellness-text-classification.pth"
        model_name_or_path = "monologg/koelectra-base-discriminator"

        checkpoint = torch.load(self.model_path, map_location=device)
        electra_config = ElectraConfig.from_pretrained(model_name_or_path)
        self.model = koElectraForSequenceClassification.from_pretrained(pretrained_model_name_or_path=model_name_or_path, config=electra_config, num_labels=359)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()

        self.tokenizer = ElectraTokenizer.from_pretrained(model_name_or_path)

        self.category = []
        idx = -1
        with open(root_path+'/data/wellness_data_for_text_classification.txt', 'r') as f:
            while True:
                line = f.readline()
                if not line:
                    break
                datas = line.strip().split("\t")
                if datas[1] != str(idx):
                    self.category.append(datas[2])
                idx += 1
Example #6
def predict_pair(model_args, data_args, training_args):
    # Set seed
    set_seed(training_args.seed)

    if 'roberta' in model_args.model_type:
        tokenizer = RobertaTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
        config = RobertaConfig.from_pretrained(model_args.model_name_or_path)
        config.num_labels = data_args.num_labels
        model = RobertaForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config)
    elif 'electra' in model_args.model_type:
        tokenizer = ElectraTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
        config = ElectraConfig.from_pretrained(model_args.model_name_or_path)
        config.num_labels = data_args.num_labels
        model = ElectraForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config)
    else:
        # default -> bert
        tokenizer = BertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
        config = BertConfig.from_pretrained(model_args.model_name_or_path)
        config.num_labels = data_args.num_labels
        model = BertForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config)

    model.to(training_args.device)

    test_df = pickle.load(open(data_args.test_data_file, 'rb'))
    test_dataset = get_dataset(data_args, tokenizer, test_df, model_args.model_type)
    data_collator = MyDataCollator()
    if training_args.local_rank != -1:
        sampler = SequentialDistributedSampler(test_dataset)
        model = torch.nn.DataParallel(model)
    else:
        n_gpu = torch.cuda.device_count()
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        sampler = SequentialSampler(test_dataset)
    print(len(test_dataset))
    dataloader = DataLoader(
        test_dataset,
        sampler=sampler,
        batch_size=training_args.eval_batch_size,
        collate_fn=data_collator,
    )

    model.eval()
    all_probs = []
    for inputs in tqdm(dataloader):
        for k, v in inputs.items():
            inputs[k] = v.to(training_args.device)
        inputs.pop('labels')
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs[0]
            probs = torch.softmax(logits, dim=-1)
            maxp, maxi = torch.max(probs, dim=-1)
            result = [(_i, _p) for _p, _i in zip(maxp, maxi)]
            all_probs.extend(result)

    with open('./{}_{}.answer_classify.result'.format(data_args.data_type, model_args.model_type), 'w', encoding='utf-8') as fout:
        for i in range(len(test_df)):
            fout.write('{} | {} | {} | {} | {}\n'.format(test_df[i][0], test_df[i][1], test_df[i][2], all_probs[i][0], all_probs[i][1]))
Example #7
 def __init__(self, output_size=24005, device='cpu'):
     super().__init__()
     self.device = device
     config = ElectraConfig.from_pretrained(
         'google/electra-small-discriminator')
     self.electra = AutoModel.from_config(config).to(device)
     self.output = nn.Linear(self.electra.config.hidden_size,
                             output_size).to(device)
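Worth noting about the snippet above: AutoModel.from_config builds the ELECTRA encoder with randomly initialized weights. If pretrained weights are wanted instead, a hedged alternative (not part of the original example) would be:

from transformers import AutoModel, ElectraConfig

config = ElectraConfig.from_pretrained('google/electra-small-discriminator')
# from_config()     -> architecture only, random weights
# from_pretrained() -> downloads and loads the trained checkpoint
electra = AutoModel.from_pretrained('google/electra-small-discriminator', config=config)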
Example #8
def get_electra():
    ids = keras.layers.Input(shape=(None, ), dtype=tf.int32, name='ids')
    att = keras.layers.Input(shape=(None, ), dtype=tf.int32, name='att')
    tok_type_ids = keras.layers.Input(shape=(None, ),
                                      dtype=tf.int32,
                                      name='tti')

    config = ElectraConfig.from_pretrained(Config.Electra.config)
    electra_model = TFElectraModel.from_pretrained(Config.Electra.model,
                                                   config=config)

    x = electra_model(ids, attention_mask=att, token_type_ids=tok_type_ids)

    x1 = keras.layers.Dropout(0.15)(x[0])
    x1 = keras.layers.Conv1D(768, 2, padding='same')(x1)
    x1 = keras.layers.LeakyReLU()(x1)
    x1 = keras.layers.LayerNormalization()(x1)
    x1 = keras.layers.Conv1D(64, 2, padding='same')(x1)
    x1 = keras.layers.LeakyReLU()(x1)
    x1 = keras.layers.LayerNormalization()(x1)
    x1 = keras.layers.Conv1D(32, 2, padding='same')(x1)
    x1 = keras.layers.Conv1D(1, 1)(x1)
    x1 = keras.layers.Flatten()(x1)
    x1 = keras.layers.Activation('softmax', dtype='float32', name='sts')(x1)

    x2 = keras.layers.Dropout(0.15)(x[0])
    x2 = keras.layers.Conv1D(768, 2, padding='same')(x2)
    x2 = keras.layers.LeakyReLU()(x2)
    x2 = keras.layers.LayerNormalization()(x2)
    x2 = keras.layers.Conv1D(64, 2, padding='same')(x2)
    x2 = keras.layers.LeakyReLU()(x2)
    x2 = keras.layers.LayerNormalization()(x2)
    x2 = keras.layers.Conv1D(32, 2, padding='same')(x2)
    x2 = keras.layers.Conv1D(1, 1)(x2)
    x2 = keras.layers.Flatten()(x2)
    x2 = keras.layers.Activation('softmax', dtype='float32', name='ets')(x2)

    model = keras.models.Model(inputs=[ids, att, tok_type_ids],
                               outputs=[x1, x2])

    optimizer = keras.optimizers.Adam(learning_rate=6e-5)
    if Config.Train.use_amp:
        optimizer = keras.mixed_precision.experimental.LossScaleOptimizer(
            optimizer, 'dynamic')
    loss = keras.losses.CategoricalCrossentropy(
        label_smoothing=Config.Train.label_smoothing)
    model.compile(loss=loss, optimizer=optimizer)

    return model
Example #9
 def bert_config(self):
     if self.bert_model_name.startswith('bert-'):
         return BertConfig.from_pretrained(self.bert_model_name,
                                           cache_dir=self.bert_cache_dir)
     elif self.bert_model_name.startswith('xlm-roberta-'):
         # check the xlm-roberta prefix before the generic 'roberta' substring,
         # otherwise XLM-RoBERTa models would fall into the RoBERTa branch
         return XLMRobertaConfig.from_pretrained(
             self.bert_model_name, cache_dir=self.bert_cache_dir)
     elif 'roberta' in self.bert_model_name:
         return RobertaConfig.from_pretrained(self.bert_model_name,
                                              cache_dir=self.bert_cache_dir)
     elif 'electra' in self.bert_model_name:
         return ElectraConfig.from_pretrained(self.bert_model_name,
                                              cache_dir=self.bert_cache_dir)
     else:
         raise ValueError('Unknown model: {}'.format(self.bert_model_name))
Example #10
def define_config(name):
    if name in [
            "bert-base-multilingual-cased",
            "sangrimlee/bert-base-multilingual-cased-korquad",
            "kykim/bert-kor-base", "monologg/kobert"
    ]:
        return BertConfig.from_pretrained(name)
    elif name in [
            "monologg/koelectra-base-v3-discriminator",
            "kykim/electra-kor-base"
    ]:
        return ElectraConfig.from_pretrained(name)
    elif name in ["xlm-roberta-large"]:
        return XLMRobertaConfig.from_pretrained(name)
    elif name in ["kykim/funnel-kor-base"]:
        return FunnelConfig.from_pretrained(name)
Example #11
def _get_bert(model_type, model_path_dict):
    if model_type == 'bert':
        config = BertConfig.from_pretrained(model_path_dict['config'])
        config.output_hidden_states = True
        bert = BertModel.from_pretrained(model_path_dict['model'],
                                         config=config)
    elif model_type == 'electra':
        config = ElectraConfig.from_pretrained(model_path_dict['config'])
        config.output_hidden_states = True
        bert = ElectraModel.from_pretrained(model_path_dict['model'],
                                            config=config)
    elif model_type == 'roberta':
        config = RobertaConfig.from_pretrained(model_path_dict['config'])
        config.output_hidden_states = True
        bert = RobertaModel.from_pretrained(model_path_dict['model'],
                                            config=config)
    return bert, config
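A hypothetical call, only to illustrate the expected shape of model_path_dict (the 'config' and 'model' keys come from the function above; both values are placeholders):

model_path_dict = {
    'config': 'google/electra-base-discriminator',  # placeholder path or hub name
    'model': 'google/electra-base-discriminator',   # placeholder path or hub name
}
bert, config = _get_bert('electra', model_path_dict)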
Example #12
def load_model(dataBunch, pretrained_path, finetuned_wgts_path, device, multi_label):

    model_type = dataBunch.model_type
    model_state_dict = None

    if torch.cuda.is_available():
        map_location = lambda storage, loc: storage.cuda()
    else:
        map_location = "cpu"

    if finetuned_wgts_path:
        model_state_dict = torch.load(finetuned_wgts_path, map_location=map_location)
    else:
        model_state_dict = None

    if multi_label is True:
        config_class, model_class, _ = MODEL_CLASSES[model_type]

        config = config_class.from_pretrained(
            str(pretrained_path), num_labels=len(dataBunch.labels)
        )

        model = model_class[1].from_pretrained(
            str(pretrained_path), config=config, state_dict=model_state_dict
        )
    else:
        if model_type == "electra":
            config = ElectraConfig.from_pretrained(
                str(pretrained_path),
                model_type=model_type,
                num_labels=len(dataBunch.labels),
            )
        else:
            config = AutoConfig.from_pretrained(
                str(pretrained_path),
                model_type=model_type,
                num_labels=len(dataBunch.labels),
            )
        model = AutoModelForSequenceClassification.from_pretrained(
            str(pretrained_path), config=config, state_dict=model_state_dict
        )

    return model.to(device)
Example #13
 def __init__(self, params, name="model", **kwargs):
     super(NERwithHFBERT, self).__init__(params, name=name, **kwargs)
     self._tag_string_mapper = get_sm(self._params.tags_fn_)
     self.tag_vocab_size = self._tag_string_mapper.size() + 2
     self._tracked_layers = dict()
     if self.pretrained_bert is None:
         if self._params.use_hf_electra_model_:
             self.pretrained_bert = TFElectraModel(ElectraConfig.from_pretrained(params.pretrained_hf_model_,cache_dir=params.hf_cache_dir_))
         else:
             self.pretrained_bert = TFBertModel(BertConfig.from_pretrained(params.pretrained_hf_model_,cache_dir=params.hf_cache_dir_))
     self._dropout = tf.keras.layers.Dropout(self._params.dropout_last)
     if self._params.bet_tagging_:
         # print(self.tag_vocab_size-1)
         # half of the classes is used plus O-Class, sos, eos
         self._layer_cls = tf.keras.layers.Dense(
             int(self._tag_string_mapper.size() // 2 + 3), activation=tf.keras.activations.softmax, name="layer_cls"
         )
         self._layer_start = tf.keras.layers.Dense(1, activation=tf.keras.activations.sigmoid, name="layer_start")
         self._layer_end = tf.keras.layers.Dense(1, activation=tf.keras.activations.sigmoid, name="layer_end")
     elif self._params.use_crf:
         self._last_layer = tf.keras.layers.Dense(self.tag_vocab_size, name="last_layer")
         self._trans_params = tf.keras.layers.Embedding(
             self.tag_vocab_size, self.tag_vocab_size, name="trans_params"
         )
         # ,embeddings_initializer=tf.keras.initializers.Constant(1))
         if self._params.crf_with_ner_rule:
             self._penalty_factor = tf.keras.layers.Embedding(1, 1, name="penalty_factor")
             # ,embeddings_initializer=tf.keras.initializers.Constant(1))
             self._penalty_absolute = tf.keras.layers.Embedding(1, 1, name="penalty_absolute")
             # ,embeddings_initializer=tf.keras.initializers.Constant(1))
         elif self.params.crf_with_ner_forb_trans:
             self._penalty_factor = tf.constant(0.0, name="penalty_factor", dtype=tf.float32)
             self._penalty_absolute = tf.constant(-100000.0, name="penalty_absolute", dtype=tf.float32)
         self.init_crf_with_ner_rule((self.tag_vocab_size - 3) // 2)
     else:
         self._last_layer = tf.keras.layers.Dense(
             self.tag_vocab_size, activation=tf.keras.activations.softmax, name="last_layer"
         )
Example #14
import os

import tensorflow as tf
from transformers import (
    ElectraConfig,
    ElectraTokenizer,
    TFElectraForMaskedLM,
    TFElectraForPreTraining,
)

from electra.utils import colorize_dis, colorize_gen

os.environ["CUDA_VISIBLE_DEVICES"] = ""

# TODO: Should I use bert-base-uncased?
tokenizer = ElectraTokenizer.from_pretrained("bert-base-uncased")

gen_config = ElectraConfig.from_pretrained("google/electra-small-generator")
dis_config = ElectraConfig.from_pretrained(
    "google/electra-small-discriminator")

# gen = TFElectraForMaskedLM.from_pretrained("google/electra-small-generator")
# dis = TFElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
gen = TFElectraForMaskedLM(config=gen_config)
dis = TFElectraForPreTraining(config=dis_config)
optimizer = tf.keras.optimizers.Adam(lr=1e-4)

# Load in WikiText-2.
filename = "/fsx/wikitext/wikitext-2-raw/wiki.test.raw"
with open(filename) as infile:
    wiki_text: str = infile.read()  # length 1,288,556

# Load in text strings.
Example #15
    c.lr = 1e-4
    c.layer_lr_decay = 0.8
    c.max_length = 512
elif c.size == "large":
    c.lr = 5e-5
    c.layer_lr_decay = 0.9
    c.max_length = 512
else:
    raise ValueError(f"Invalid size {c.size}")
if c.pretrained_checkpoint is None:
    c.max_length = 512  # All public models are the ++ variants, which use max_length 512

# huggingface/transformers
hf_tokenizer = ElectraTokenizerFast.from_pretrained(
    f"google/electra-{c.size}-discriminator")
electra_config = ElectraConfig.from_pretrained(
    f"google/electra-{c.size}-discriminator")

# wsc
if c.wsc_trick:
    from _utils.wsc_trick import *  # importing spacy model takes time

# logging
# light logging callback here is to only log the last score and avoid exceeding the api access limit
if c.logger == "neptune":
    import neptune
    from fastai.callback.neptune import NeptuneCallback

    class LightNeptuneCallback(NeptuneCallback):
        def after_batch(self):
            pass
Example #16
def get_model(args, tokenizer):
    config = ElectraConfig.from_pretrained('google/electra-base-discriminator')
    config.num_labels = 4
    config.vocab_size = tokenizer.get_vocab_size() if tokenizer else args.vocab_size
    model = ElectraForSequenceClassification(config)
    return model
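Note that ElectraForSequenceClassification(config) above creates a classifier with randomly initialized weights, which is what you want when a custom tokenizer changes the vocabulary. If the vocabulary matches the checkpoint and the pretrained encoder should be kept, a sketch along these lines could be used instead:

from transformers import ElectraConfig, ElectraForSequenceClassification

config = ElectraConfig.from_pretrained('google/electra-base-discriminator')
config.num_labels = 4
# from_pretrained keeps the pretrained encoder; only the classification head is new
model = ElectraForSequenceClassification.from_pretrained(
    'google/electra-base-discriminator', config=config)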
Example #17
  save_ckpt_path = f"{checkpoint_path}/koelectra-wellnesee-text-classification.pth"
  model_name_or_path = "monologg/koelectra-base-discriminator"

  # Load the answers and categories
  category, answer = load_wellness_answer()

  ctx = "cuda" if torch.cuda.is_available() else "cpu"
  device = torch.device(ctx)

  # Load the saved checkpoint
  checkpoint = torch.load(save_ckpt_path, map_location=device)

  # Electra Tokenizer
  tokenizer = ElectraTokenizer.from_pretrained(model_name_or_path)

  electra_config = ElectraConfig.from_pretrained(model_name_or_path)
  model = koElectraForSequenceClassification.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
                                                             config=electra_config,
                                                             num_labels=359)
  model.load_state_dict(checkpoint['model_state_dict'])
  model.to(device)
  model.eval()


  while 1:
    sent = input('\nQuestion: ')  # e.g. '요즘 기분이 우울한 느낌이에요' ("I've been feeling down lately")
    data = koelectra_input(tokenizer,sent, device,512)
    # print(data)

    output = model(**data)
Example #18
    def __init__(
        self,
        model_type,
        model_name,
        generator_name=None,
        discriminator_name=None,
        train_files=None,
        args=None,
        use_cuda=True,
        cuda_device=-1,
        **kwargs,
    ):

        """
        Initializes a LanguageModelingModel.

        Args:
            model_type: The type of model (gpt2, openai-gpt, bert, roberta, distilbert, camembert)
            model_name: Default Transformer model name or path to a directory containing Transformer model file (pytorch_model.bin).
            generator_name (optional): A pretrained model name or path to a directory containing an ELECTRA generator model.
            discriminator_name (optional): A pretrained model name or path to a directory containing an ELECTRA discriminator model.
            args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
            train_files (optional): List of files to be used when training the tokenizer.
            use_cuda (optional): Use GPU if available. Setting to False will force model to use CPU only.
            cuda_device (optional): Specific GPU that should be used. Will use the first available GPU by default.
            **kwargs (optional): For providing proxies, force_download, resume_download, cache_dir and other options specific to the 'from_pretrained' implementation where this will be supplied.
        """  # noqa: ignore flake8"

        self.args = self._load_model_args(model_name)

        if isinstance(args, dict):
            self.args.update_from_dict(args)
        elif isinstance(args, LanguageModelingArgs):
            self.args = args

        if "sweep_config" in kwargs:
            sweep_config = kwargs.pop("sweep_config")
            sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
            self.args.update_from_dict(sweep_values)

        if self.args.manual_seed:
            random.seed(self.args.manual_seed)
            np.random.seed(self.args.manual_seed)
            torch.manual_seed(self.args.manual_seed)
            if self.args.n_gpu > 0:
                torch.cuda.manual_seed_all(self.args.manual_seed)

        if self.args.local_rank != -1:
            logger.info(f"local_rank: {self.args.local_rank}")
            torch.distributed.init_process_group(backend="nccl")
            cuda_device = self.args.local_rank

        if use_cuda:
            if torch.cuda.is_available():
                if cuda_device == -1:
                    self.device = torch.device("cuda")
                else:
                    self.device = torch.device(f"cuda:{cuda_device}")
            else:
                raise ValueError(
                    "'use_cuda' set to True when cuda is unavailable."
                    " Make sure CUDA is available or set use_cuda=False."
                )
        else:
            self.device = "cpu"

        self.results = {}

        if not use_cuda:
            self.args.fp16 = False

        self.args.model_name = model_name
        self.args.model_type = model_type

        config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
        self.tokenizer_class = tokenizer_class
        new_tokenizer = False

        if self.args.tokenizer_name:
            self.tokenizer = tokenizer_class.from_pretrained(self.args.tokenizer_name, cache_dir=self.args.cache_dir)
        elif self.args.model_name:
            if self.args.model_name == "electra":
                self.tokenizer = tokenizer_class.from_pretrained(
                    generator_name, cache_dir=self.args.cache_dir, **kwargs
                )
                self.args.tokenizer_name = self.args.model_name
            else:
                self.tokenizer = tokenizer_class.from_pretrained(model_name, cache_dir=self.args.cache_dir, **kwargs)
                self.args.tokenizer_name = self.args.model_name
        else:
            if not train_files:
                raise ValueError(
                    "model_name and tokenizer_name are not specified."
                    "You must specify train_files to train a Tokenizer."
                )
            else:
                self.train_tokenizer(train_files)
                new_tokenizer = True

        if self.args.config_name:
            self.config = config_class.from_pretrained(self.args.config_name, cache_dir=self.args.cache_dir)
        elif self.args.model_name and self.args.model_name != "electra":
            self.config = config_class.from_pretrained(model_name, cache_dir=self.args.cache_dir, **kwargs)
        else:
            self.config = config_class(**self.args.config, **kwargs)
        if self.args.vocab_size:
            self.config.vocab_size = self.args.vocab_size
        if new_tokenizer:
            self.config.vocab_size = len(self.tokenizer)

        if self.args.model_type == "electra":
            if generator_name:
                self.generator_config = ElectraConfig.from_pretrained(generator_name)
            elif self.args.model_name:
                self.generator_config = ElectraConfig.from_pretrained(
                    os.path.join(self.args.model_name, "generator_config"), **kwargs,
                )
            else:
                self.generator_config = ElectraConfig(**self.args.generator_config, **kwargs)
                if new_tokenizer:
                    self.generator_config.vocab_size = len(self.tokenizer)

            if discriminator_name:
                self.discriminator_config = ElectraConfig.from_pretrained(discriminator_name)
            elif self.args.model_name:
                self.discriminator_config = ElectraConfig.from_pretrained(
                    os.path.join(self.args.model_name, "discriminator_config"), **kwargs,
                )
            else:
                self.discriminator_config = ElectraConfig(**self.args.discriminator_config, **kwargs)
                if new_tokenizer:
                    self.discriminator_config.vocab_size = len(self.tokenizer)

        if self.args.block_size <= 0:
            self.args.block_size = min(self.args.max_seq_length, self.tokenizer.max_len)
        else:
            self.args.block_size = min(self.args.block_size, self.tokenizer.max_len, self.args.max_seq_length)

        if self.args.model_name:
            if self.args.model_type == "electra":
                if self.args.model_name == "electra":
                    generator_model = ElectraForMaskedLM.from_pretrained(generator_name)
                    discriminator_model = ElectraForPreTraining.from_pretrained(discriminator_name)
                    self.model = ElectraForLanguageModelingModel(
                        config=self.config,
                        generator_model=generator_model,
                        discriminator_model=discriminator_model,
                        generator_config=self.generator_config,
                        discriminator_config=self.discriminator_config,
                        tie_generator_and_discriminator_embeddings=self.args.tie_generator_and_discriminator_embeddings,
                    )
                    model_to_resize = (
                        self.model.generator_model.module
                        if hasattr(self.model.generator_model, "module")
                        else self.model.generator_model
                    )
                    model_to_resize.resize_token_embeddings(len(self.tokenizer))

                    model_to_resize = (
                        self.model.discriminator_model.module
                        if hasattr(self.model.discriminator_model, "module")
                        else self.model.discriminator_model
                    )
                    model_to_resize.resize_token_embeddings(len(self.tokenizer))
                    self.model.generator_model = generator_model
                    self.model.discriminator_model = discriminator_model
                else:
                    self.model = model_class.from_pretrained(
                        model_name,
                        config=self.config,
                        cache_dir=self.args.cache_dir,
                        generator_config=self.generator_config,
                        discriminator_config=self.discriminator_config,
                        **kwargs,
                    )
                    self.model.load_state_dict(torch.load(os.path.join(self.args.model_name, "pytorch_model.bin")))
            else:
                self.model = model_class.from_pretrained(
                    model_name, config=self.config, cache_dir=self.args.cache_dir, **kwargs,
                )
        else:
            logger.info(" Training language model from scratch")
            if self.args.model_type == "electra":
                generator_model = ElectraForMaskedLM(config=self.generator_config)
                discriminator_model = ElectraForPreTraining(config=self.discriminator_config)
                self.model = ElectraForLanguageModelingModel(
                    config=self.config,
                    generator_model=generator_model,
                    discriminator_model=discriminator_model,
                    generator_config=self.generator_config,
                    discriminator_config=self.discriminator_config,
                    tie_generator_and_discriminator_embeddings=self.args.tie_generator_and_discriminator_embeddings,
                )
                model_to_resize = (
                    self.model.generator_model.module
                    if hasattr(self.model.generator_model, "module")
                    else self.model.generator_model
                )
                model_to_resize.resize_token_embeddings(len(self.tokenizer))

                model_to_resize = (
                    self.model.discriminator_model.module
                    if hasattr(self.model.discriminator_model, "module")
                    else self.model.discriminator_model
                )
                model_to_resize.resize_token_embeddings(len(self.tokenizer))
            else:
                self.model = model_class(config=self.config)
                model_to_resize = self.model.module if hasattr(self.model, "module") else self.model
                model_to_resize.resize_token_embeddings(len(self.tokenizer))

        if model_type in ["camembert", "xlmroberta"]:
            warnings.warn(
                f"use_multiprocessing automatically disabled as {model_type}"
                " fails when using multiprocessing for feature conversion."
            )
            self.args.use_multiprocessing = False

        if self.args.wandb_project and not wandb_available:
            warnings.warn("wandb_project specified but wandb is not available. Wandb disabled.")
            self.args.wandb_project = None
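To tie the long constructor above back to its docstring, a minimal instantiation sketch might look like the following; every name here is a placeholder and the class comes from the surrounding library, so this is illustrative rather than a verified call:

model = LanguageModelingModel(
    "electra",                                                 # model_type
    "electra",                                                 # model_name; this value selects the generator/discriminator path
    generator_name="google/electra-small-generator",           # placeholder generator checkpoint
    discriminator_name="google/electra-small-discriminator",   # placeholder discriminator checkpoint
    use_cuda=False,
)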
Example #19
def main(cli_args):
    args = AttrDict(cli_args)
    logger.info("Training/evaluation parameters {}".format(args))

    args.output_dir = os.path.join(args.ckpt_dir, args.task)

    set_seed(args)

    output_mode = "classification"
    if "nsmc" in args.train_file:
        processor = NSMCProcessor(args)
    elif "kornli" in args.train_file:
        processor = KorNLIProcessor(args)
    elif "paws" in args.train_file:
        processor = PawsProcessor(args)
    elif "korsts" in args.train_file:
        processor = KorSTSProcessor(args)
        output_mode = "regression"
    elif "question-pair" in args.train_file:
        processor = QuestionPairProcessor(args)
    elif "hate-speech" in args.train_file:
        processor = HateSpeechProcessor(args)
    elif "naver-ner" in args.train_file:
        processor = NaverNerProcessor(args)
    else:
        processor = IntentProcessor(args)
    args["output_mode"] = output_mode
    labels = processor.get_labels()

    config = ElectraConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=len(labels),
        id2label={str(i): label for i, label in enumerate(labels)},
        label2id={label: i for i, label in enumerate(labels)},
    )
    if args.mecab:
        tokenizer = KoNLPyBertTokenizer(
            konlpy_wordpiece=KoNLPyWordPieceTokenizer(Mecab(), use_tag=False),
            vocab_file=os.path.join(args.model_name_or_path, "vocab.txt"),
            do_lower_case=args.do_lower_case,
        )
    else:
        tokenizer = ElectraTokenizer.from_pretrained(
            args.model_name_or_path, do_lower_case=args.do_lower_case
        )

    if "naver-ner" in args.train_file:
        model = ElectraForTokenClassification.from_pretrained(
            args.model_name_or_path, config=config
        )
    else:
        model = ElectraForSequenceClassification.from_pretrained(
            args.model_name_or_path, config=config
        )

    #Re-init
    if args.do_reinit:
        init_layer(model.electra.encoder.layer, top_n_layer=1)
    
    # GPU or CPU
    args.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
    model.to(args.device)

    # Load dataset
    if "naver-ner" in args.train_file:
        train_dataset = (
            ner_load_and_cache_examples(args, tokenizer, mode="train")
            if args.train_file
            else None
        )
        dev_dataset = (
            ner_load_and_cache_examples(args, tokenizer, mode="dev")
            if args.dev_file
            else None
        )
        test_dataset = (
            ner_load_and_cache_examples(args, tokenizer, mode="test")
            if args.test_file
            else None
        )
    else:
        train_dataset = (
            seq_cls_load_and_cache_examples(args, tokenizer, mode="train")
            if args.train_file
            else None
        )
        dev_dataset = (
            seq_cls_load_and_cache_examples(args, tokenizer, mode="dev")
            if args.dev_file
            else None
        )
        test_dataset = (
            seq_cls_load_and_cache_examples(args, tokenizer, mode="test")
            if args.test_file
            else None
        )

    if dev_dataset is None:
        args.evaluate_test_during_training = (
            True  # If there is no dev dataset, only use testset
        )

    if args.do_train:
        global_step, tr_loss = train(
            args, model, labels, train_dataset, dev_dataset, test_dataset
        )
        logger.info(" global_step = {}, average loss = {}".format(global_step, tr_loss))

    if args.do_eval and not args.do_nni:
        results = {}
        checkpoints = list(
            os.path.dirname(c)
            for c in sorted(
                glob.glob(
                    args.output_dir + "/**/" + "pytorch_model.bin", recursive=True
                )
            )
        )
        if not args.eval_all_checkpoints:
            checkpoints = checkpoints[-1:]
        else:
            logging.getLogger("transformers.configuration_utils").setLevel(
                logging.WARN
            )  # Reduce logging
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN
            )  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1]
            if "naver-ner" in args.train_file:
                model = ElectraForTokenClassification.from_pretrained(checkpoint)
            else:
                model = ElectraForSequenceClassification.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(
                args,
                model,
                test_dataset,
                mode="test",
                labels=labels,
                global_step=global_step,
            )
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as f_w:
            for key in sorted(results.keys()):
                f_w.write("{} = {}\n".format(key, str(results[key])))
Example #20
 def _get_encoder(self):
     config = ElectraConfig.from_pretrained(self.backbone)
     q_encoder = ElectraEncoder(config=config).cuda()
     return q_encoder
Example #21
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

# tf.config.optimizer.set_jit(USE_XLA)
if USE_AMP:  # params.use_amp:
    # tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    tf.keras.mixed_precision.experimental.set_policy(policy)

# tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})

# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
config = ElectraConfig.from_pretrained("google/electra-base-discriminator",
                                       num_labels=num_labels)
tokenizer = ElectraTokenizer.from_pretrained(
    "google/electra-base-discriminator")
model = TFElectraForTokenClassification.from_pretrained(
    "google/electra-base-discriminator", config=config)

# Load dataset via TensorFlow Datasets
# data, info = tensorflow_datasets.load(f"glue/{TFDS_TASK}", with_info=True)
# train_examples = info.splits["train"].num_examples

# MNLI expects either validation_matched or validation_mismatched
# valid_examples = info.splits[val_string].num_examples
# test_examples = info.splits[test_string].num_examples

# Replace train and test examples
data_processor = glue.glue_processors[TASK]()
Example #22
def chatbot_tag(diary):
    root_path = str(pathlib.Path(__file__).parent.absolute())
    checkpoint_path = f"{root_path}/checkpoint"
    save_ckpt_path = f"{checkpoint_path}/koelectra-wellness-text-classification.pth"
    model_name_or_path = "monologg/koelectra-base-discriminator"

    # Load the answers and categories
    category = []
    idx = -1
    # with open(root_path+'/data/wellness_data_for_text_classification.txt', 'r') as f:
    with open('..\data\wellness_data_for_text_classification.txt',
              'r',
              encoding="UTF-8") as f:
        while True:
            line = f.readline()
            if not line:
                break
            datas = line.strip().split("\t")
            if datas[1] != str(idx):
                category.append(datas[2])
                idx += 1

    ctx = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(ctx)

    # Load the saved checkpoint
    # checkpoint = torch.load(save_ckpt_path, map_location=device)
    checkpoint = torch.load(
        "../checkpoint/koelectra-wellness-text-classification.pth",
        map_location=device)

    # Electra Tokenizer
    tokenizer = ElectraTokenizer.from_pretrained(model_name_or_path)

    electra_config = ElectraConfig.from_pretrained(model_name_or_path)
    model = koElectraForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        config=electra_config,
        num_labels=359)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    sent = diary  # e.g. '요즘 기분이 우울한 느낌이에요' ("I've been feeling down lately")
    data = koelectra_input(tokenizer, sent, device, 512)
    # print(data)

    output = model(**data)

    logit = output[0] if isinstance(output, tuple) else output
    # apply softmax over the label dimension to turn logits into probabilities
    softmax_logit = nn.Softmax(dim=-1)(logit)
    softmax_logit = softmax_logit[0].squeeze()

    max_index = torch.argmax(softmax_logit).item()
    max_index_value = softmax_logit[torch.argmax(softmax_logit)].item()

    print(f'index: {category[max_index]}, value: {max_index_value}')
    print('-' * 50)

    emotion_tag = f'{category[max_index]}'

    # return jsonify({"emotion_tag": emotion_tag})
    return emotion_tag
Example #23
if __name__ == '__main__':
    config_path = './config/koelectra-small-v2.json'
    ckpt_path = './ckpt/koelectra-small-v2-mz-ckpt/checkpoint-11500/pytorch_model.bin'

    with open(config_path) as f:
        args = AttrDict(json.load(f))

    device = torch.device('cuda')

    checkpoint = torch.load(ckpt_path, map_location=device)

    tokenizer = ElectraTokenizer.from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case
    )
    config = ElectraConfig.from_pretrained(args.model_name_or_path)
    model = koElectraForSequenceClassification.from_pretrained(
            args.model_name_or_path,
            config=config,
            num_labels=60
            )

    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    f = open('./final_test.csv', 'r', newline='') # f = open('final_test.csv', 'a', newline='')
    lines = csv.reader(f)
    next(lines)
    talks = []
    result = {}
Example #24
def main():
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    set_seed(args.seed)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s :: %(levelname)s :: %(message)s')

    if args.numnet_model is not None:
        config = BertConfig.from_pretrained(
            args.model_name, num_labels=1)  # 1 label for regression
        # if args.contrastive:
        #     model = ContrastiveElectra.from_pretrained(args.model_name, config=config)
        # else:
        model = BertForSequenceClassification.from_pretrained(args.model_name,
                                                              config=config)
        state_dicts = torch.load(args.numnet_model)
        if "model" in state_dicts:
            logging.info("Loading in mutual electra format state_dicts.")
            model.load_state_dict(state_dicts["model"], strict=False)
        else:
            logging.info("Loading model weights only.")
            model.load_state_dict(state_dicts, strict=False)
    else:
        config = ElectraConfig.from_pretrained(
            args.model_name, num_labels=1)  # 1 label for regression
        model = ElectraForSequenceClassification.from_pretrained(
            args.model_name, config=config)
        if args.local_model_path is not None:
            state_dicts = torch.load(args.local_model_path)
            model.load_state_dict(state_dicts["model"])

    tokenizer = ElectraTokenizer.from_pretrained(args.model_name,
                                                 do_lower_case=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # TODO enable multi-gpu training if necessary
    pretrain_train_dataset = DapoDataset(args.data_dir, "train",
                                         tokenizer) if args.pretrain else None
    pretrain_dev_dataset = DapoDataset(args.data_dir, "dev",
                                       tokenizer) if args.pretrain else None

    if args.train:
        if args.contrastive:
            train_dataset = ContrastiveDataset(args.data_dir, "train",
                                               tokenizer)
            train_dataloader = DataLoader(train_dataset,
                                          batch_size=args.train_batch_size,
                                          shuffle=False,
                                          num_workers=8,
                                          collate_fn=mutual_contrast_collate)
            dev_dataset = ContrastiveDataset(
                args.data_dir, "dev",
                tokenizer) if args.eval or args.test else None
            dev_dataloader = DataLoader(dev_dataset,
                                        batch_size=args.train_batch_size,
                                        shuffle=False,
                                        num_workers=8,
                                        collate_fn=mutual_contrast_collate
                                        ) if dev_dataset is not None else None
        else:
            train_dataset = MutualDataset(args.data_dir, "train", tokenizer)
            train_dataloader = DataLoader(train_dataset,
                                          batch_size=args.train_batch_size,
                                          shuffle=True,
                                          num_workers=8,
                                          collate_fn=mutual_collate)
            dev_dataset = MutualDataset(
                args.data_dir, "dev",
                tokenizer) if args.eval or args.test else None
            dev_dataloader = DataLoader(
                dev_dataset,
                batch_size=args.train_batch_size,
                shuffle=False,
                num_workers=8,
                collate_fn=mutual_collate) if dev_dataset is not None else None

    else:
        train_dataset, train_dataloader = None, None

    # TODO: add test_dataset if we want to submit to leaderboard

    pretrain_train_dataloader = DataLoader(
        pretrain_train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=8,
        collate_fn=dapo_collate
    ) if pretrain_train_dataset is not None else None
    pretrain_dev_dataloader = DataLoader(
        pretrain_dev_dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        num_workers=8,
        collate_fn=dapo_collate) if pretrain_dev_dataset is not None else None

    # currently eval_batch_size = train_batch_size

    if args.pretrain:
        logging.info("Start pretraining...")
        args.eval = True
        trainer = Trainer(args, model, device, pretrain_train_dataloader,
                          pretrain_dev_dataloader)
        trainer.train()
        return  # fine-tuning should be done separately

    if args.train:
        logging.info("Start training...")
        trainer = Trainer(args, model, device, train_dataloader,
                          dev_dataloader)
        trainer.train()

    # TODO: currently testing is on the dev set
    if args.test:
        logging.info("Start testing...")
        tester = Tester(args, model, device, dev_dataset, dev_dataloader)
        tester.test()
Example #25
assert c.schedule in ['original_linear', 'separate_linear', 'one_cycle', 'adjusted_one_cycle']
if not c.base_run_name: c.base_run_name = str(datetime.now(timezone(timedelta(hours=+8))))[6:-13].replace(' ','').replace(':','').replace('-','')
if not c.seed: c.seed = random.randint(0, 99999)
c.run_name = f'{c.base_run_name}_{c.seed}'
if c.gen_smooth_label is True: c.gen_smooth_label = 0.1
if c.disc_smooth_label is True: c.disc_smooth_label = 0.1

# Setting of different sizes
i = ['small', 'base', 'large'].index(c.size)
c.mask_prob = [0.15, 0.15, 0.25][i]
c.lr = [5e-4, 2e-4, 2e-4][i]
c.bs = [128, 256, 2048][i]
c.steps = [10**6, 766*1000, 400*1000][i]
c.max_length = [128, 512, 512][i]
generator_size_divisor = [4, 3, 4][i]
disc_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-discriminator')
gen_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-generator')
# note that the public electra-small model is actually small++ and does not scale down the generator size
gen_config.hidden_size = int(disc_config.hidden_size/generator_size_divisor)
gen_config.num_attention_heads = int(disc_config.num_attention_heads/generator_size_divisor)
gen_config.intermediate_size = int(disc_config.intermediate_size/generator_size_divisor)
hf_tokenizer = ElectraTokenizerFast.from_pretrained(f"google/electra-{c.size}-generator")

# Path to data
Path('./datasets').mkdir(exist_ok=True)
Path('./checkpoints/pretrain').mkdir(exist_ok=True, parents=True)
if c.size in ['small', 'base']:
  wiki_cache_dir = Path("./datasets/wikipedia/20200501.en/1.0.0")
  book_cache_dir = Path("./datasets/bookcorpus/plain_text/1.0.0")
  wbdl_cache_dir = Path("./datasets/wikibook_dl")
  wbdl_cache_dir.mkdir(exist_ok=True)
Example #26
def train():
    # load model and tokenizer
    #MODEL_NAME = "bert-base-multilingual-cased"
    MODEL_NAME = "monologg/koelectra-base-v3-discriminator"
    tokenizer = ElectraTokenizer.from_pretrained(MODEL_NAME)
    print(tokenizer.tokenize("이순신은 조선 중기의 무신이다."))
    print(tokenizer.tokenize("아버지가방에들어가신다."))
    tokenized_str = tokenizer.tokenize("이순신은 조선 중기의 무신이다." +
                                       tokenizer.sep_token + "아버지가방에들어가신다.")
    print(tokenized_str)

    # load dataset
    train_dataset = load_data("/opt/ml/input/data/train/train.tsv")
    train_label = train_dataset['label'].values

    # tokenizing dataset
    tokenized_train = tokenized_dataset(train_dataset, tokenizer)

    # make dataset for pytorch.
    RE_train_dataset = RE_Dataset(tokenized_train, train_label)
    train_dataset, dev_dataset = torch.utils.data.random_split(
        RE_train_dataset, [7000, 2001])

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # setting model hyperparameter
    bert_config = ElectraConfig.from_pretrained(MODEL_NAME)
    bert_config.num_labels = 42
    model = ElectraForSequenceClassification.from_pretrained(
        MODEL_NAME, config=bert_config)
    #model.parameters
    model.to(device)

    # Besides the options used here, there are many other options available.
    # See https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments for details.
    training_args = TrainingArguments(
        output_dir='./results',  # output directory
        save_total_limit=4,  # number of total save model.
        load_best_model_at_end=True,
        save_steps=100,  # model saving step.
        num_train_epochs=10,  # total number of training epochs
        learning_rate=5e-5,  # learning_rate
        per_device_train_batch_size=8,  # batch size per device during training
        per_device_eval_batch_size=8,  # batch size for evaluation
        warmup_steps=500,  # number of warmup steps for learning rate scheduler
        weight_decay=0.01,  # strength of weight decay
        logging_dir='./logs',  # directory for storing logs
        logging_steps=100,  # log saving step.
        evaluation_strategy='steps',  # evaluation strategy to adopt during training
        # `no`: No evaluation during training.
        # `steps`: Evaluate every `eval_steps`.
        # `epoch`: Evaluate every end of epoch.
        eval_steps=100,  # evaluation step.
        dataloader_num_workers=3,
        label_smoothing_factor=0.5)
    trainer = Trainer(
        model=model,  # the instantiated 🤗 Transformers model to be trained
        args=training_args,  # training arguments, defined above
        train_dataset=train_dataset,  # training dataset
        eval_dataset=dev_dataset,  # evaluation dataset
        compute_metrics=compute_metrics,  # define metrics function
    )

    # train model
    trainer.train()
Example #27
                        help="Number of gpus to use for distributed training.")
    parser.add_argument("--output_dir", default="ckpts")
    parser.add_argument("--local_model_path", default=None, type=str)
    parser.add_argument("--numnet_model", default=None, type=str)
    parser.add_argument("--constrasitve", action="store_true")
    parser.add_argument("--speaker_embeddings", action="store_true")

    return parser.parse_args()


args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s :: %(levelname)s :: %(message)s')

config = ElectraConfig.from_pretrained(args.model_name,
                                       num_labels=1)  # 1 label for regression
tokenizer = ElectraTokenizer.from_pretrained(args.model_name,
                                             do_lower_case=True)

model = SpeakerAwareElectraModelForSequenceClassification(
    args.model_name, config, 2)

model.load_state_dict(
    torch.load(args.checkpoint, map_location=torch.device('cpu'))['model'])
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

dev_dataset = MutualDataset('mutual/', "dev", tokenizer)
dev_dataloader = DataLoader(dev_dataset,
Example #28
    def __init__(
        self,
        model_type,
        model_name,
        generator_name=None,
        discriminator_name=None,
        train_files=None,
        args=None,
        use_cuda=True,
        cuda_device=-1,
        **kwargs,
    ):

        """
        Initializes a LanguageModelingModel.

        Args:
            model_type: The type of model (gpt2, openai-gpt, bert, roberta, distilbert, camembert)
            model_name: Default Transformer model name or path to a directory containing Transformer model file (pytorch_model.bin).
            generator_name (optional): A pretrained model name or path to a directory containing an ELECTRA generator model.
            discriminator_name (optional): A pretrained model name or path to a directory containing an ELECTRA discriminator model.
            args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
            train_files (optional): List of files to be used when training the tokenizer.
            use_cuda (optional): Use GPU if available. Setting to False will force model to use CPU only.
            cuda_device (optional): Specific GPU that should be used. Will use the first available GPU by default.
            **kwargs (optional): For providing proxies, force_download, resume_download, cache_dir and other options specific to the 'from_pretrained' implementation where this will be supplied.
        """  # noqa: ignore flake8"

        if args and "manual_seed" in args:
            random.seed(args["manual_seed"])
            np.random.seed(args["manual_seed"])
            torch.manual_seed(args["manual_seed"])
            if "n_gpu" in args and args["n_gpu"] > 0:
                torch.cuda.manual_seed_all(args["manual_seed"])

        if use_cuda:
            if torch.cuda.is_available():
                if cuda_device == -1:
                    self.device = torch.device("cuda")
                else:
                    self.device = torch.device(f"cuda:{cuda_device}")
            else:
                raise ValueError(
                    "'use_cuda' set to True when cuda is unavailable."
                    " Make sure CUDA is available or set use_cuda=False."
                )
        else:
            self.device = "cpu"

        self.results = {}

        self.args = {
            "dataset_type": "None",
            "dataset_class": None,
            "custom_tokenizer": None,
            "block_size": -1,
            "mlm": True,
            "mlm_probability": 0.15,
            "max_steps": -1,
            "config_name": None,
            "tokenizer_name": None,
            "min_frequency": 2,
            "special_tokens": ["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
            "sliding_window": False,
            "stride": 0.8,
            "generator_config": {},
            "discriminator_config": {},
            "vocab_size": None,
        }

        self.args.update(global_args)

        if not use_cuda:
            self.args["fp16"] = False

        if args:
            self.args.update(args)

        self.args["model_name"] = model_name
        self.args["model_type"] = model_type

        config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
        self.tokenizer_class = tokenizer_class
        new_tokenizer = False

        if self.args["tokenizer_name"]:
            self.tokenizer = tokenizer_class.from_pretrained(
                self.args["tokenizer_name"], cache_dir=self.args["cache_dir"]
            )
        elif self.args["model_name"]:
            self.tokenizer = tokenizer_class.from_pretrained(model_name, cache_dir=self.args["cache_dir"], **kwargs)
            self.args["tokenizer_name"] = self.args["model_name"]
        else:
            if not train_files:
                raise ValueError(
                    "model_name and tokenizer_name are not specified."
                    "You must specify train_files to train a Tokenizer."
                )
            else:
                self.train_tokenizer(train_files)
                new_tokenizer = True

        if self.args["config_name"]:
            self.config = config_class.from_pretrained(self.args["config_name"], cache_dir=self.args["cache_dir"])
        elif self.args["model_name"]:
            self.config = config_class.from_pretrained(model_name, cache_dir=self.args["cache_dir"], **kwargs)
        else:
            self.config = config_class(**self.args["config"], **kwargs)
        if self.args["vocab_size"]:
            self.config.vocab_size = self.args["vocab_size"]
        if new_tokenizer:
            self.config.vocab_size = len(self.tokenizer)

        if self.args["model_type"] == "electra":
            if generator_name:
                self.generator_config = ElectraConfig.from_pretrained(generator_name)
            elif self.args["model_name"]:
                self.generator_config = ElectraConfig.from_pretrained(
                    os.path.join(self.args["model_name"], "generator_config"), **kwargs,
                )
            else:
                self.generator_config = ElectraConfig(**self.args["generator_config"], **kwargs)
                if new_tokenizer:
                    self.generator_config.vocab_size = len(self.tokenizer)

            if discriminator_name:
                self.discriminator_config = ElectraConfig.from_pretrained(discriminator_name)
            elif self.args["model_name"]:
                self.discriminator_config = ElectraConfig.from_pretrained(
                    os.path.join(self.args["model_name"], "discriminator_config"), **kwargs,
                )
            else:
                self.discriminator_config = ElectraConfig(**self.args["discriminator_config"], **kwargs)
                if new_tokenizer:
                    self.discriminator_config.vocab_size = len(self.tokenizer)

        if self.args["block_size"] <= 0:
            self.args["block_size"] = min(self.args["max_seq_length"], self.tokenizer.max_len)
        else:
            self.args["block_size"] = min(self.args["block_size"], self.tokenizer.max_len, self.args["max_seq_length"])

        if self.args["model_name"]:
            if self.args["model_type"] == "electra":
                self.model = model_class.from_pretrained(
                    model_name,
                    config=self.config,
                    cache_dir=self.args["cache_dir"],
                    generator_config=self.generator_config,
                    discriminator_config=self.discriminator_config,
                    **kwargs,
                )
                self.model.load_state_dict(torch.load(os.path.join(self.args["model_name"], "pytorch_model.bin")))
            else:
                self.model = model_class.from_pretrained(
                    model_name, config=self.config, cache_dir=self.args["cache_dir"], **kwargs,
                )
        else:
            logger.info(" Training language model from scratch")
            if self.args["model_type"] == "electra":
                generator_model = ElectraForMaskedLM(config=self.generator_config)
                discriminator_model = ElectraForPreTraining(config=self.discriminator_config)
                self.model = ElectraForLanguageModelingModel(
                    config=self.config,
                    generator_model=generator_model,
                    discriminator_model=discriminator_model,
                    generator_config=self.generator_config,
                    discriminator_config=self.discriminator_config,
                )
                model_to_resize = (
                    self.model.generator_model.module
                    if hasattr(self.model.generator_model, "module")
                    else self.model.generator_model
                )
                model_to_resize.resize_token_embeddings(len(self.tokenizer))

                model_to_resize = (
                    self.model.discriminator_model.module
                    if hasattr(self.model.discriminator_model, "module")
                    else self.model.discriminator_model
                )
                model_to_resize.resize_token_embeddings(len(self.tokenizer))
            else:
                self.model = model_class(config=self.config)
                model_to_resize = self.model.module if hasattr(self.model, "module") else self.model
                model_to_resize.resize_token_embeddings(len(self.tokenizer))

        if model_type in ["camembert", "xlmroberta"]:
            warnings.warn(
                f"use_multiprocessing automatically disabled as {model_type}"
                " fails when using multiprocessing for feature conversion."
            )
            self.args["use_multiprocessing"] = False

        if self.args["wandb_project"] and not wandb_available:
            warnings.warn("wandb_project specified but wandb is not available. Wandb disabled.")
            self.args["wandb_project"] = None
Exemplo n.º 29
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected")

    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected")

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The input data dir. Should contain the .json files for the task. "
        "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there. "
        "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there. "
        "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3",
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help=
        "If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help=
        "If null_score - best_non_null is greater than the threshold predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_train",
                        action="store_true",
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action="store_true",
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        default=True,
        action="store_true",
        help="Run evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )

    parser.add_argument("--logging_steps",
                        type=int,
                        default=100,
                        help="Log every X updates steps.")
    parser.add_argument("--save_steps",
                        type=int,
                        default=10000,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Whether not to use CUDA when available")
    parser.add_argument("--overwrite_output_dir",
                        action="store_true",
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--server_ip",
                        type=str,
                        default="",
                        help="Can be used for distant debugging.")
    parser.add_argument("--server_port",
                        type=str,
                        default="",
                        help="Can be used for distant debugging.")

    parser.add_argument(
        "--threads",
        type=int,
        default=1,
        help="multiple threads for converting example to features")

    ### DO NOT MODIFY THIS BLOCK ###
    # arguments for nsml
    parser.add_argument('--pause', type=int, default=0)
    parser.add_argument('--mode', type=str, default='train')
    ################################

    args = parser.parse_args()

    # for NSML
    args.data_dir = os.path.join(DATASET_PATH, args.data_dir)

    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
            and args.do_train and not args.overwrite_output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
        logger.warning('IF args.n_gpu : ' + str(args.n_gpu) + ' / device : ' +
                       str(device) + '\n')
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
        logger.warning('ELSE args.n_gpu : ' + str(args.n_gpu) +
                       ' / device : ' + str(device) + '\n')

    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
        filename='log.log')
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    logger.warning("Model Loading ..")

    config = ElectraConfig.from_pretrained(args.model_name_or_path)
    model = ElectraForQuestionAnswering.from_pretrained(
        args.model_name_or_path, config=config)
    tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path,
                                                 do_lower_case=False)

    logger.warning("Model Loading Completed")

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)

    ### DO NOT MODIFY THIS BLOCK ###
    if IS_ON_NSML:
        bind_nsml(model, tokenizer, args)
        if args.pause:
            nsml.paused(scope=locals())
    ################################

    logger.info("Training/evaluation parameters %s", args)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is
    # set. Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running
    # `--fp16_opt_level="O2"` will remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex

            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                tokenizer,
                                                evaluate=False,
                                                output_examples=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)
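
For context, a brief inference sketch showing how a trained question-answering checkpoint produced by this script could be queried. It assumes the stock transformers ElectraForQuestionAnswering and ElectraTokenizer classes (the script above may import custom variants), a transformers release whose model outputs expose start_logits/end_logits, and an illustrative checkpoint path.

import torch
from transformers import ElectraTokenizer, ElectraForQuestionAnswering

model_path = "outputs/checkpoint-final"  # illustrative path to a trained checkpoint
tokenizer = ElectraTokenizer.from_pretrained(model_path, do_lower_case=False)
model = ElectraForQuestionAnswering.from_pretrained(model_path)
model.eval()

question = "Who wrote Hamlet?"
context = "Hamlet is a tragedy written by William Shakespeare around 1600."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Take the most likely start/end positions and decode the answer span.
start = int(torch.argmax(outputs.start_logits))
end = int(torch.argmax(outputs.end_logits))
print(tokenizer.decode(inputs["input_ids"][0][start:end + 1]))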
Exemplo n.º 30
0
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments, PathArguments)
    )
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings) == 0, f"The args {remaining_strings} could not be parsed."

    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    if train_args.eager == "true":
        tf.config.experimental_run_functions_eagerly(True)

    tokenizer = ElectraTokenizerFast.from_pretrained("bert-base-uncased")

    gen_config = ElectraConfig.from_pretrained(f"google/electra-{model_args.model_size}-generator")
    dis_config = ElectraConfig.from_pretrained(
        f"google/electra-{model_args.model_size}-discriminator"
    )

    gen = TFElectraForMaskedLM(config=gen_config)
    dis = TFElectraForPreTraining(config=dis_config)
    optimizer = get_adamw_optimizer(train_args)

    # Tie the weights
    if model_args.electra_tie_weights == "true":
        gen.electra.embeddings = dis.electra.embeddings

    loaded_optimizer_weights = None
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix, model_args.checkpoint_path)
        dis_ckpt, gen_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(checkpoint_path)
        if hvd.rank() == 0:
            dis.load_weights(dis_ckpt)
            gen.load_weights(gen_ckpt)
            loaded_optimizer_weights = np.load(optimizer_ckpt, allow_pickle=True)

    start_time = time.perf_counter()

    if hvd.rank() == 0:
        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            TqdmLoggingHandler(),
        ]
        summary_writer = None  # Only create a writer if we make it through a successful step
        logging.basicConfig(level=level, format=format, handlers=handlers)
        wandb_run_name = None

        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if log_args.run_name is None:
            metadata = (
                f"electra-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.total_steps}steps"
            )
            run_name = (
                f"{current_time}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
            )
        else:
            run_name = log_args.run_name

    logger.info(f"Training with dataset at {path_args.train_dir}")
    logger.info(f"Validating with dataset at {path_args.val_dir}")

    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir, "*.tfrecord*")
    validation_glob = os.path.join(path_args.filesystem_prefix, path_args.val_dir, "*.tfrecord*")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)
    logger.info(
        f"Number of train files {len(train_filenames)}, number of validation files {len(validation_filenames)}"
    )

    tf_train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
        max_seq_length=data_args.max_seq_length,
    )

    tf_train_dataset = tf_train_dataset.prefetch(buffer_size=8)

    if hvd.rank() == 0:
        tf_val_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
            max_seq_length=data_args.max_seq_length,
        )
        tf_val_dataset = tf_val_dataset.prefetch(buffer_size=8)

    wandb_run_name = None

    step = 1
    for batch in tf_train_dataset:
        learning_rate = optimizer.learning_rate(step=tf.constant(step, dtype=tf.float32))
        ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        train_result = train_step(
            optimizer=optimizer,
            gen=gen,
            dis=dis,
            ids=ids,
            attention_mask=attention_mask,
            mask_token_id=tokenizer.mask_token_id,
        )

        if step == 1:
            # Horovod broadcast
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(gen.variables, root_rank=0)
            hvd.broadcast_variables(dis.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            step = optimizer.get_weights()[0]

        is_final_step = step >= train_args.total_steps
        if hvd.rank() == 0:
            do_log = step % log_args.log_frequency == 0
            do_checkpoint = (step > 1) and (
                (step % log_args.checkpoint_frequency == 0) or is_final_step
            )
            do_validation = step % log_args.validation_frequency == 0

            if do_log:
                elapsed_time = time.perf_counter() - start_time  # Inaccurate for the first log, since start_time predates the loop
                it_s = log_args.log_frequency / elapsed_time
                start_time = time.perf_counter()
                description = f"Step {step} -- gen_loss: {train_result.gen_loss:.3f}, dis_loss: {train_result.dis_loss:.3f}, gen_acc: {train_result.gen_acc:.3f}, dis_acc: {train_result.dis_acc:.3f}, it/s: {it_s:.3f}\n"
                logger.info(description)

            if do_validation:
                for batch in tf_val_dataset.take(1):
                    val_ids = batch["input_ids"]
                    val_attention_mask = batch["attention_mask"]
                    val_result = val_step(
                        gen=gen,
                        dis=dis,
                        ids=val_ids,
                        attention_mask=val_attention_mask,
                        mask_token_id=tokenizer.mask_token_id,
                    )
                    log_example(
                        tokenizer,
                        val_ids,
                        val_result.masked_ids,
                        val_result.corruption_mask,
                        val_result.gen_ids,
                        val_result.dis_preds,
                    )
                    description = f"VALIDATION, Step {step} -- val_gen_loss: {val_result.gen_loss:.3f}, val_dis_loss: {val_result.dis_loss:.3f}, val_gen_acc: {val_result.gen_acc:.3f}, val_dis_acc: {val_result.dis_acc:.3f}\n"
                    logger.info(description)

            train_metrics = {
                "learning_rate": learning_rate,
                "train/loss": train_result.loss,
                "train/gen_loss": train_result.gen_loss,
                "train/dis_loss": train_result.dis_loss,
                "train/gen_acc": train_result.gen_acc,
                "train/dis_acc": train_result.dis_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_result.loss,
                    "val/gen_loss": val_result.gen_loss,
                    "val/dis_loss": val_result.dis_loss,
                    "val/gen_acc": val_result.gen_acc,
                    "val/dis_acc": val_result.dis_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_log:
                all_metrics = {"it_s": it_s, **all_metrics}

            if is_wandb_available():
                if wandb_run_name is None:
                    config = {
                        **asdict(model_args),
                        **asdict(data_args),
                        **asdict(train_args),
                        **asdict(log_args),
                        **asdict(path_args),
                        "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                        "n_gpus": hvd.size(),
                    }
                    wandb.init(config=config, project="electra")
                    wandb.run.save()
                    wandb_run_name = wandb.run.name
                wandb.log({"step": step, **all_metrics})

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix, path_args.log_dir, run_name)
                )
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    **asdict(path_args),
                    "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                    "n_gpus": hvd.size(),
                }

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=step)

            if do_checkpoint:
                dis_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-discriminator.ckpt",
                )
                gen_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-generator.ckpt",
                )
                optimizer_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-optimizer.npy",
                )
                logger.info(
                    f"Saving discriminator model at {dis_model_ckpt}, generator model at {gen_model_ckpt}, optimizer at {optimizer_ckpt}"
                )
                dis.save_weights(dis_model_ckpt)
                gen.save_weights(gen_model_ckpt)
                np.save(optimizer_ckpt, optimizer.get_weights())

        step += 1
        if is_final_step:
            break
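
The train_step and val_step functions used in this example are imported from elsewhere in the project; below is a simplified, Horovod-free sketch of the computation such an ELECTRA pretraining step performs, included only to make the loop above easier to follow. All names are illustrative, the generator predictions are taken with argmax rather than sampled, and the 50x discriminator loss weight follows the ELECTRA paper.

import tensorflow as tf


def simple_electra_step(gen, dis, optimizer, ids, attention_mask, mask_token_id, mlm_prob=0.15):
    """One simplified ELECTRA pretraining step: mask, generate, discriminate."""
    real_tokens = tf.cast(attention_mask, tf.bool)
    # Pick ~15% of the non-padding positions to mask out.
    masked_positions = tf.logical_and(tf.random.uniform(tf.shape(ids)) < mlm_prob, real_tokens)
    masked_ids = tf.where(masked_positions, tf.cast(mask_token_id, ids.dtype), ids)

    with tf.GradientTape() as tape:
        # Generator: masked language modelling loss on the masked positions only.
        gen_logits = gen(masked_ids, attention_mask=attention_mask)[0]
        gen_loss = tf.reduce_mean(
            tf.boolean_mask(
                tf.keras.losses.sparse_categorical_crossentropy(ids, gen_logits, from_logits=True),
                masked_positions,
            )
        )

        # Replace the masked positions with the generator's predictions (argmax here,
        # sampling in the paper) and ask the discriminator which tokens were replaced.
        predicted = tf.cast(tf.argmax(gen_logits, axis=-1), ids.dtype)
        corrupted_ids = tf.where(masked_positions, predicted, ids)
        is_replaced = tf.cast(tf.not_equal(corrupted_ids, ids), tf.float32)

        dis_logits = dis(corrupted_ids, attention_mask=attention_mask)[0]
        dis_loss = tf.reduce_mean(
            tf.boolean_mask(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=is_replaced, logits=dis_logits),
                real_tokens,
            )
        )

        # The ELECTRA objective weights the discriminator loss much higher (lambda = 50).
        loss = gen_loss + 50.0 * dis_loss

    variables = gen.trainable_variables + dis.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
    return loss, gen_loss, dis_loss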