Example #1
    def test_push_to_hub_in_organization(self):
        config = BertConfig(vocab_size=99,
                            hidden_size=32,
                            num_hidden_layers=5,
                            num_attention_heads=4,
                            intermediate_size=37)
        config.push_to_hub("valid_org/test-config-org",
                           use_auth_token=self._token)

        new_config = BertConfig.from_pretrained("valid_org/test-config-org")
        for k, v in config.to_dict().items():
            if k != "transformers_version":
                self.assertEqual(v, getattr(new_config, k))

        # Reset repo
        delete_repo(token=self._token, repo_id="valid_org/test-config-org")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            config.save_pretrained(tmp_dir,
                                   repo_id="valid_org/test-config-org",
                                   push_to_hub=True,
                                   use_auth_token=self._token)

        new_config = BertConfig.from_pretrained("valid_org/test-config-org")
        for k, v in config.to_dict().items():
            if k != "transformers_version":
                self.assertEqual(v, getattr(new_config, k))
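
The test methods in this document (here and in Examples #2 and #4 below) come from a unittest class and rely on surrounding imports and an authenticated Hub token. A minimal, self-contained sketch of that context follows; the class name and token handling are assumptions, not the original harness:

import os
import tempfile
import unittest

from huggingface_hub import delete_repo
from transformers import BertConfig

USER = "__dummy_user__"  # placeholder account name used by Example #2 (assumption)


class ConfigPushToHubTester(unittest.TestCase):  # class name is an assumption
    @classmethod
    def setUpClass(cls):
        # Assumes a Hub token with write access is supplied via the environment.
        cls._token = os.environ["HF_TOKEN"]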
Example #2
    def test_push_to_hub(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        with tempfile.TemporaryDirectory() as tmp_dir:
            config.save_pretrained(os.path.join(tmp_dir, "test-config"), push_to_hub=True, use_auth_token=self._token)

            new_config = BertConfig.from_pretrained(f"{USER}/test-config")
            for k, v in config.__dict__.items():
                if k != "transformers_version":
                    self.assertEqual(v, getattr(new_config, k))
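
Note that `use_auth_token` has since been deprecated in `transformers` in favor of `token`. A self-contained sketch of the same upload on a recent release (the exact deprecation version is approximate):

# Same upload with the newer `token` keyword; `use_auth_token` is deprecated.
import os
import tempfile

from transformers import BertConfig

config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=5,
                    num_attention_heads=4, intermediate_size=37)
with tempfile.TemporaryDirectory() as tmp_dir:
    # Reading the token from the environment is an assumption for this sketch.
    config.save_pretrained(os.path.join(tmp_dir, "test-config"),
                           push_to_hub=True, token=os.environ.get("HF_TOKEN"))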
Example #3
import os

import wandb
from transformers import (BertConfig, BertForMaskedLM, BertTokenizer,
                          Trainer, TrainingArguments)

# seed_everyone, read_data, TcDataset, Data_Collator and check_dir are
# project-local helpers assumed to be defined elsewhere in this repository.


def main():
    seed_everyone(20210318)

    raw_data_path = '../user_data/duality_pair_pretrain_no_nsp.tsv'
    output_dir = '../user_data/tmp_data/pretrain_output/whole_word_mask_bert_output'

    tokenizer = BertTokenizer.from_pretrained('../user_data/vocab_.txt')
    data = read_data(raw_data_path, tokenizer, debug=True)

    train_dataset = TcDataset(data)

    config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        max_position_embeddings=100,
        type_vocab_size=2,
        pad_token_id=0,
    )
    model = BertForMaskedLM(config=config)
    wandb.init(project="bert_oppo_pretrain1", entity="zjw", dir=output_dir)
    data_collator = Data_Collator(max_seq_len=42,
                                  tokenizer=tokenizer,
                                  mlm_p=0.15)

    model_save_dir = os.path.join(output_dir, 'best_model_ckpt')
    tokenizer_and_config = os.path.join(output_dir, 'tokenizer_and_config')
    check_dir(model_save_dir)
    check_dir(tokenizer_and_config)

    training_args = TrainingArguments(output_dir=output_dir,
                                      overwrite_output_dir=True,
                                      num_train_epochs=100,
                                      fp16_backend='auto',
                                      per_device_train_batch_size=128,
                                      save_steps=500,
                                      logging_steps=500,
                                      save_total_limit=10,
                                      prediction_loss_only=True,
                                      run_name='0419',
                                      logging_first_step=True,
                                      dataloader_num_workers=4,
                                      disable_tqdm=False,
                                      seed=202104)

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()
    trainer.save_model(model_save_dir)
    config.save_pretrained(tokenizer_and_config)
    tokenizer.save_pretrained(tokenizer_and_config)


if __name__ == '__main__':
    main()
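
Example #3 calls several project-local helpers (`seed_everyone`, `read_data`, `TcDataset`, `Data_Collator`, `check_dir`) whose definitions are not shown. A plausible sketch of the two simplest ones, inferred from their names rather than taken from the source:

import os
import random

import numpy as np
import torch


def seed_everyone(seed: int) -> None:
    # Plausible implementation: seed every RNG the training run touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def check_dir(path: str) -> None:
    # Plausible implementation: create the directory if it does not exist.
    os.makedirs(path, exist_ok=True)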
Example #4

    def test_push_to_hub_in_organization(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            config.save_pretrained(
                tmp_dir,
                push_to_hub=True,
                repo_name="test-model-org",
                use_auth_token=self._token,
                organization="valid_org",
            )

            new_config = BertConfig.from_pretrained("valid_org/test-model-org")
            for k, v in config.__dict__.items():
                if k != "transformers_version":
                    self.assertEqual(v, getattr(new_config, k))
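
Example #4 targets the older push API, where the repo is spelled as `repo_name=` plus `organization=`; newer `transformers` releases take a single `repo_id`, as Example #1 already does. A self-contained sketch of the equivalent call (version boundary approximate):

# Equivalent upload with the single repo_id form; the token source is an assumption.
import os
import tempfile

from transformers import BertConfig

config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=5,
                    num_attention_heads=4, intermediate_size=37)
with tempfile.TemporaryDirectory() as tmp_dir:
    config.save_pretrained(tmp_dir, push_to_hub=True,
                           repo_id="valid_org/test-model-org",
                           token=os.environ.get("HF_TOKEN"))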