Example #1
0
    def test_advanced_inputs(self):
        """Check BartModel initialization and its handling of decoder/encoder masks.

        Covers: encoder token embeddings tied to the shared table, the init std of
        a few weights, equivalence between an explicitly passed decoder mask and
        the one the model builds itself, the effect of an all-zero decoder mask,
        and acceptance of an integer (long) encoder attention mask.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
            config, inputs_dict["input_ids"])
        model = BartModel(config)
        model.to(torch_device)
        model.eval()

        # test init: encoder token embeddings must equal the shared embedding table
        self.assertTrue(
            (model.encoder.embed_tokens.weight == model.shared.weight).all().item())

        def _check_var(module):
            """Check that we initialized various parameters from N(0, config.init_std)."""
            self.assertAlmostEqual(
                torch.std(module.weight).item(), config.init_std, 2)

        _check_var(model.encoder.embed_tokens)
        _check_var(model.encoder.layers[0].self_attn.k_proj)
        _check_var(model.encoder.layers[0].fc1)
        _check_var(model.encoder.embed_positions)

        # Call the module itself (not .forward) so nn.Module hooks run; gradients
        # are not needed because this test only inspects outputs.
        with torch.no_grad():
            decoder_features_with_created_mask = model(**inputs_dict)[0]
            decoder_features_with_passed_mask = model(
                decoder_attention_mask=decoder_attn_mask,
                decoder_input_ids=decoder_input_ids,
                **inputs_dict)[0]
            # Passing the mask explicitly must match letting the model build it.
            _assert_tensors_equal(decoder_features_with_passed_mask,
                                  decoder_features_with_created_mask)

            # Replace the decoder mask with all zeros; if the real mask masked
            # anything (large negative entries), the outputs should change.
            useless_mask = torch.zeros_like(decoder_attn_mask)
            decoder_features = model(decoder_attention_mask=useless_mask,
                                     **inputs_dict)[0]
            self.assertIsInstance(
                decoder_features, torch.Tensor)  # no hidden states or attentions
            self.assertEqual(decoder_features.size(),
                             (self.model_tester.batch_size,
                              self.model_tester.seq_length, config.d_model))
            if decoder_attn_mask.min().item() < -1e3:  # some tokens were masked
                self.assertFalse(
                    (decoder_features_with_created_mask == decoder_features
                     ).all().item())

            # Test different encoder attention masks: a long-typed mask must give
            # the same result as the default one.
            decoder_features_with_long_encoder_mask = model(
                inputs_dict["input_ids"],
                attention_mask=inputs_dict["attention_mask"].long())[0]
            _assert_tensors_equal(decoder_features_with_long_encoder_mask,
                                  decoder_features_with_created_mask)
Example #2
0
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak a fairseq BART checkpoint's weights into our BART structure.

    Loads ``checkpoint_path`` from the ``pytorch/fairseq`` torch.hub repo and copies
    its weights into the matching Hugging Face model. It then checks that both
    implementations produce identical outputs on ``SAMPLE_TEXT`` and saves the
    converted model.

    Args:
        checkpoint_path: fairseq hub name. ``"bart.large"`` and ``"bart.large.cnn"``
            are converted as a ``BartModel`` (the latter wrapped in
            ``BartForMaskedLM`` before saving). Any other name is treated as the
            MNLI classification checkpoint. The Hugging Face config/tokenizer name
            is derived by replacing ``.`` with ``-``.
        pytorch_dump_folder_path: directory the converted model is saved to;
            created (including missing parents) if needed.
    """
    bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
    bart.eval()  # disable dropout
    bart.model.upgrade_state_dict(bart.model.state_dict())
    hf_model_name = checkpoint_path.replace(".", "-")
    config = BartConfig.from_pretrained(hf_model_name)

    # Both tokenizers must agree, otherwise the output comparison below is meaningless.
    tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
    # return_tensors="pt" already yields a batched tensor, so no extra unsqueeze.
    # Comparing flattened ids makes the batch dimension irrelevant, and
    # torch.equal returns False (instead of raising) when the lengths differ.
    hf_tokens = BartTokenizer.from_pretrained(hf_model_name).encode(
        SAMPLE_TEXT, return_tensors="pt")
    assert torch.equal(tokens.flatten(), hf_tokens.flatten()), \
        "fairseq and Hugging Face tokenizations differ"

    if checkpoint_path in ["bart.large", "bart.large.cnn"]:
        state_dict = bart.model.state_dict()
        for k in IGNORE_KEYS:
            state_dict.pop(k, None)
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        model = BartModel(config)
        their_output = bart.extract_features(tokens)
    else:  # MNLI Case
        state_dict = bart.state_dict()
        for k in IGNORE_KEYS:
            state_dict.pop(k, None)
        state_dict["model.shared.weight"] = state_dict[
            "model.decoder.embed_tokens.weight"]
        for src, dest in rename_keys:
            rename_key(state_dict, src, dest)
        model = BartForSequenceClassification(config)
        # Run the fairseq "mnli" classification head. The previous call,
        # bart.eval("mnli", ...), hit nn.Module.eval(mode) and raised TypeError.
        their_output = bart.predict("mnli", tokens, return_logits=True)

    # Load state dict
    model.load_state_dict(state_dict)
    model.eval()

    # Check results
    if checkpoint_path == "bart.large.cnn":  # generate doesnt work yet
        model = BartForMaskedLM(config, base_model=model)
        assert "lm_head.weight" in model.state_dict()
        # NOTE(review): this compares against max_position_embeddings; presumably
        # the intended quantity is the hidden size — confirm against how
        # BartForMaskedLM builds lm_head.
        assert model.lm_head.out_features == config.max_position_embeddings
        model.eval()
        our_outputs = model.model(tokens)[0]
    else:
        our_outputs = model(tokens)[0]
    assert their_output.shape == our_outputs.shape
    assert (their_output == our_outputs).all().item()
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
Example #3
0
    def test_initialization_more(self):
        """Verify BartModel's embedding tying and weight-initialization scale.

        The encoder's token embeddings must hold the same values as the shared
        table. A handful of sampled weights must have an empirical standard
        deviation matching ``config.init_std`` to two decimal places.
        """
        config, _unused_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        model = BartModel(config)
        model.to(torch_device)
        model.eval()

        encoder = model.encoder
        # test init: tied embeddings
        same_values = encoder.embed_tokens.weight == model.shared.weight
        self.assertTrue(same_values.all().item())

        first_layer = encoder.layers[0]
        sampled_modules = (
            encoder.embed_tokens,
            first_layer.self_attn.k_proj,
            first_layer.fc1,
            encoder.embed_positions,
        )
        # Each sampled weight should look like a draw from N(0, config.init_std).
        for sampled in sampled_modules:
            observed_std = torch.std(sampled.weight).item()
            self.assertAlmostEqual(observed_std, config.init_std, 2)
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak a fairseq BART checkpoint's weights into our BART structure.

    Loads ``checkpoint_path`` from the ``pytorch/fairseq`` torch.hub repo and copies
    its weights into a Hugging Face model built from a default ``BartConfig``. It
    then checks that both implementations give identical outputs on
    ``SAMPLE_TEXT`` and saves the converted model.

    Args:
        checkpoint_path: fairseq hub name. ``"bart.large"`` is converted as a
            ``BartModel``; any other name takes the MNLI classification path.
        pytorch_dump_folder_path: directory the converted model is saved to;
            created (including missing parents) if needed.
    """
    bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
    bart.eval()  # disable dropout
    bart.model.upgrade_state_dict(bart.model.state_dict())
    # NOTE(review): the default config is used for every checkpoint — confirm it
    # matches the MNLI checkpoint (e.g. its classification head size).
    config = BartConfig()

    # Both tokenizers must agree, otherwise the output comparison below is meaningless.
    tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
    # Ask the HF tokenizer for a tensor explicitly. Without return_tensors,
    # encode() yields a list of ids, and the old `.unsqueeze(0)` call failed on it.
    # Flattened ids are compared so the batch dimension does not matter.
    hf_tokens = BartTokenizer.from_pretrained("bart-large").encode(
        SAMPLE_TEXT, return_tensors="pt")
    assert torch.equal(tokens.flatten(), hf_tokens.flatten()), \
        "fairseq and Hugging Face tokenizations differ"

    if checkpoint_path == "bart.large":
        state_dict = bart.model.state_dict()
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        model = BartModel(config)
        their_output = bart.extract_features(tokens)
    else:  # MNLI Case
        state_dict = bart.state_dict()
        state_dict["model.shared.weight"] = state_dict[
            "model.decoder.embed_tokens.weight"]
        for src, dest in rename_keys:
            rename_key(state_dict, src, dest)
        state_dict.pop("_float_tensor", None)
        model = BartForSequenceClassification(config)
        their_output = bart.predict("mnli", tokens, return_logits=True)

    for k in IGNORE_KEYS:
        state_dict.pop(k, None)
    model.load_state_dict(state_dict)
    model.eval()
    # Call the module (not .forward) so nn.Module hooks are honoured.
    our_outputs = model(tokens)[0]

    assert their_output.shape == our_outputs.shape
    assert (their_output == our_outputs).all().item()
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)