def average_models(
    old_model: nsp.EncoderDecoderWPointerModel,
    new_model: nsp.EncoderDecoderWPointerModel,
    new_model_weight: float,
    inplace: bool = False,
):
    """Averages new_model and old_model parameters.

    Parameters are averaged with weights new_model_weight and (1 - new_model_weight).
    If inplace, mutates new_model object.
    """
    if not (0 <= new_model_weight <= 1):
        raise ValueError(
            f"new_model_weight should be between 0 and 1, got {new_model_weight} instead."
        )
    if not inplace:
        new_model_tmp = nsp.EncoderDecoderWPointerModel(new_model.config)
        new_model_tmp.load_state_dict(new_model.state_dict())
        new_model = new_model_tmp

    old_named_params = old_model.state_dict()
    new_named_params = new_model.state_dict()

    if old_named_params.keys() != new_named_params.keys():
        raise RuntimeError("Model should have the same parameters")

    for name, old_param in old_named_params.items():
        new_named_params[name] = (new_model_weight * new_named_params[name] +
                                  (1 - new_model_weight) * old_param)

    new_model.load_state_dict(new_named_params, strict=True)
    return new_model
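
# Example usage (a hedged sketch, not part of the original module): averaging a
# fine-tuned checkpoint with the checkpoint it was fine-tuned from. The paths
# below are hypothetical; both checkpoints must share the same config.
#
#   old_model = nsp.EncoderDecoderWPointerModel.from_pretrained("runs/pretrained")
#   new_model = nsp.EncoderDecoderWPointerModel.from_pretrained("runs/finetuned")
#   averaged = average_models(old_model, new_model, new_model_weight=0.5)
#   averaged.save_pretrained("runs/averaged")
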
def iterative_prediction(
    model: nsp.EncoderDecoderWPointerModel,
    dataloader,
    schema_tokenizer: nsp.TopSchemaTokenizer,
    max_len,
    num_beams,
    device="cpu",
    return_tokens=False,
):
    """Executes inference-time prediction loop.

    Returns:
        A tuple of two elements (predictions_ids, predictions_str):
            predictions_ids: list of lists of token ids with special tokens removed
            predictions_str: list of strings if return_tokens is False,
                or list of lists of tokens if return_tokens is True
    """
    model = model.to(device)

    predictions_ids = []
    predictions_str = []
    text_tokenizer = schema_tokenizer.src_tokenizer

    for batch in tqdm(dataloader, desc="generation"):
        prediction_batch: torch.LongTensor = model.generate(
            input_ids=batch["input_ids"].to(device),
            pointer_mask=batch["pointer_mask"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_length=max_len,
            num_beams=num_beams,
            pad_token_id=text_tokenizer.pad_token_id,
            bos_token_id=schema_tokenizer.bos_token_id,
            eos_token_id=schema_tokenizer.eos_token_id,
        )

        for i, prediction in enumerate(prediction_batch):
            prediction = [
                p for p in prediction.cpu().numpy()
                if p not in schema_tokenizer.special_ids
            ]
            predictions_ids.append(prediction)

            prediction_str: str = schema_tokenizer.decode(
                prediction,
                batch["input_ids"][i],
                skip_special_tokens=True,
                return_tokens=return_tokens,
            )
            predictions_str.append(prediction_str)

    return predictions_ids, predictions_str
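
# Example usage (a hedged sketch; the dataloader, tokenizer, and generation
# settings are placeholders, set up the same way as in the inference script below):
#
#   predictions_ids, predictions_str = iterative_prediction(
#       model, dataloader, schema_tokenizer, max_len=68, num_beams=1, device="cuda")
#   for token_ids, text in zip(predictions_ids, predictions_str):
#       print(len(token_ids), text)  # ids without special tokens and the decoded parse
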
    def test_shape_on_random_data(self):
        set_seed(42)

        bs = 3
        src_len = 5
        tgt_len = 7

        encoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=17,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        encoder = transformers.BertModel(encoder_config)

        # the decoder vocabulary is the schema vocab plus pointer embeddings
        decoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=23,
            is_decoder=True,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        decoder = transformers.BertModel(decoder_config)

        # logits are projected into schema vocab and combined with pointer scores
        max_pointer = src_len + 3
        model = EncoderDecoderWPointerModel(encoder=encoder,
                                            decoder=decoder,
                                            max_src_len=max_pointer)

        x_enc = torch.randint(0, encoder_config.vocab_size, size=(bs, src_len))
        x_dec = torch.randint(0, decoder_config.vocab_size, size=(bs, tgt_len))

        out = model(input_ids=x_enc, decoder_input_ids=x_dec)

        # different encoders return different numbers of outputs,
        # e.g. BERT returns two, but DistilBERT returns only one
        self.assertGreaterEqual(len(out), 4)

        schema_vocab = decoder_config.vocab_size - max_pointer

        combined_logits = out[0]
        expected_shape = (bs, tgt_len, schema_vocab + src_len)
        self.assertEqual(combined_logits.shape, expected_shape)
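        # worked arithmetic for this test: decoder vocab_size=23 and
        # max_pointer=src_len+3=8, so schema_vocab=23-8=15 and the combined
        # logits cover 15 schema tokens + 5 source positions = 20 classes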

        decoder_hidden = out[1]
        expected_shape = (bs, tgt_len, decoder_config.hidden_size)
        self.assertEqual(decoder_hidden.shape, expected_shape)

        pooler_output = out[2]
        expected_shape = (bs, decoder_config.hidden_size)
        self.assertEqual(pooler_output.shape, expected_shape)

        encoder_hidden = out[3]
        expected_shape = (bs, src_len, encoder_config.hidden_size)
        self.assertEqual(encoder_hidden.shape, expected_shape)
    def test_save_load(self):
        src_vocab_size = 23
        tgt_vocab_size = 17

        model = EncoderDecoderWPointerModel.from_parameters(
            layers=1,
            hidden=32,
            heads=2,
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            max_src_len=7,
            dropout=0,
        )

        input_ids = torch.randint(src_vocab_size, size=(3, 7))
        tgt_sequence = torch.randint(tgt_vocab_size, size=(3, 11))
        decoder_input_ids = tgt_sequence[:, :-1].contiguous()
        labels = tgt_sequence[:, 1:].contiguous()

        expected_output = model(input_ids=input_ids,
                                decoder_input_ids=decoder_input_ids,
                                labels=labels)

        os.mkdir(self.output_dir)
        model.save_pretrained(self.output_dir)

        loaded_model = EncoderDecoderWPointerModel.from_pretrained(
            self.output_dir)
        self.assertDictEqual(model.config.to_dict(),
                             loaded_model.config.to_dict())
        for i, (p1, p2) in enumerate(
                zip(model.parameters(), loaded_model.parameters())):
            self.assertTrue(torch.allclose(p1, p2))

        output = loaded_model(input_ids=input_ids,
                              decoder_input_ids=decoder_input_ids,
                              labels=labels)

        self.assertEqual(len(output), len(expected_output))
        self.assertTrue(torch.allclose(expected_output[0], output[0]))  # loss
        self.assertTrue(torch.allclose(expected_output[1], output[1]))  # logits
    def test_loss_computation(self):
        torch.manual_seed(42)
        src_vocab_size = 17
        tgt_vocab_size = 23

        encoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=src_vocab_size,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        encoder = transformers.BertModel(encoder_config)

        max_position = 7
        decoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=tgt_vocab_size + max_position,
            is_decoder=True,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        decoder = transformers.BertModel(decoder_config)

        model = EncoderDecoderWPointerModel(encoder=encoder,
                                            decoder=decoder,
                                            max_src_len=7)

        # similar to real data
        src_seq = torch.LongTensor([[1, 6, 12, 15, 2, 0, 0],
                                    [1, 6, 12, 15, 5, 3, 2]])
        tgt_seq = torch.LongTensor([
            [8, 6, 4, 10, 11, 8, 5, 1, 12, 7, 7, 0, 0],
            [8, 6, 4, 10, 11, 8, 5, 1, 12, 13, 14, 7, 7],
        ])
        mask = torch.FloatTensor([[0, 1, 1, 1, 0, 0, 0],
                                  [0, 1, 1, 1, 1, 1, 0]])

        loss = model(
            input_ids=src_seq,
            decoder_input_ids=tgt_seq,
            pointer_mask=mask,
            labels=tgt_seq,
        )[0]

        self.assertEqual(loss.shape, torch.Size([]))
        self.assertEqual(loss.dtype, torch.float32)
        self.assertGreater(loss, 0)
    def test_move_norm(self):
        src_vocab_size = 23
        tgt_vocab_size = 17

        model = EncoderDecoderWPointerModel.from_parameters(
            layers=1,
            hidden=32,
            heads=2,
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            max_src_len=7,
            dropout=0,
            move_norm=0.1,
        )

        self.assertIsNotNone(model.initial_params)
        # check that model parameters do not include initial_params
        self.assertEqual(len(list(model.parameters())),
                         len(model.initial_params))

        # check that model updates do not change initial_params
        bs, src_len, tgt_len = 3, 5, 7
        x_enc = torch.randint(0, src_vocab_size, size=(bs, src_len))
        x_dec = torch.randint(0, tgt_vocab_size, size=(bs, tgt_len))

        dec_inp = x_dec[:, :-1].contiguous()
        labels = x_dec[:, 1:].contiguous()

        optimizer = torch.optim.SGD(model.parameters(), 1e-3)

        out = model(input_ids=x_enc, decoder_input_ids=dec_inp, labels=labels)

        loss = out[0]
        loss.backward()

        optimizer.step()

        for n, p1 in model.named_parameters():
            if "pooler" in n:
                # we do not use pooler weights
                continue

            p2 = model.initial_params[n]
            self.assertTrue(torch.any(p2 != p1), msg=n)

        # check norm computation
        norm = model.get_move_norm()
        self.assertGreater(norm, 0)
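
    # For reference (a hedged sketch, not necessarily the exact implementation of
    # get_move_norm): the move norm is conceptually the distance between the
    # current parameters and the frozen initial copy, e.g.
    #
    #   norm = torch.sqrt(sum(
    #       torch.sum((p - model.initial_params[n]) ** 2)
    #       for n, p in model.named_parameters() if "pooler" not in n))
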
    def test_shape_on_real_data_batched(self):
        set_seed(42)
        src_vocab_size = 17
        tgt_vocab_size = 23
        max_position = 7

        encoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=src_vocab_size,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        encoder = transformers.BertModel(encoder_config)

        decoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=tgt_vocab_size + max_position,
            is_decoder=True,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        decoder = transformers.BertModel(decoder_config)

        model = EncoderDecoderWPointerModel(encoder=encoder,
                                            decoder=decoder,
                                            max_src_len=max_position)

        # similar to real data
        src_seq = torch.LongTensor([[1, 6, 12, 15, 2, 0, 0],
                                    [1, 6, 12, 15, 5, 3, 2]])
        tgt_seq = torch.LongTensor([
            [8, 6, 4, 10, 11, 8, 5, 1, 12, 7, 7, 0, 0],
            [8, 6, 4, 10, 11, 8, 5, 1, 12, 13, 14, 7, 7],
        ])
        mask = torch.FloatTensor([[0, 1, 1, 1, 0, 0, 0],
                                  [0, 1, 1, 1, 1, 1, 0]])

        combined_logits = model(input_ids=src_seq,
                                decoder_input_ids=tgt_seq,
                                pointer_mask=mask)[0]

        expected_shape = (2, tgt_seq.shape[1],
                          tgt_vocab_size + src_seq.shape[1])
        self.assertEqual(combined_logits.shape, expected_shape)
    def test_shape_on_real_data(self):
        set_seed(42)
        src_vocab_size = 17
        tgt_vocab_size = 23
        max_position = 5

        encoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=src_vocab_size,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        encoder = transformers.BertModel(encoder_config)

        decoder_config = transformers.BertConfig(
            hidden_size=11,
            intermediate_size=44,
            vocab_size=tgt_vocab_size + max_position,
            is_decoder=True,
            num_hidden_layers=1,
            num_attention_heads=1,
        )
        decoder = transformers.BertModel(decoder_config)

        model = EncoderDecoderWPointerModel(encoder=encoder,
                                            decoder=decoder,
                                            max_src_len=max_position)

        # similar to real data
        # e.g. '[CLS] Directions to Lowell [SEP]'
        src_seq = torch.LongTensor([[1, 6, 12, 15, 2]])
        # e.g. '[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]'
        tgt_seq = torch.LongTensor([[8, 6, 4, 10, 11, 8, 5, 1, 12, 7, 7]])
        mask = torch.FloatTensor([[0, 1, 1, 1, 0]])
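        # pointer_mask is 1 only for the content tokens ('Directions', 'to',
        # 'Lowell'), so the pointer network cannot select [CLS] or [SEP]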

        combined_logits = model(input_ids=src_seq,
                                decoder_input_ids=tgt_seq,
                                pointer_mask=mask)[0]

        expected_shape = (1, tgt_seq.shape[1],
                          tgt_vocab_size + src_seq.shape[1])
        self.assertEqual(combined_logits.shape, expected_shape)
    def test_register_weight_consolidation_buffer(self):
        src_vocab_size = 23
        tgt_vocab_size = 17

        model = EncoderDecoderWPointerModel.from_parameters(
            layers=1,
            hidden=32,
            heads=2,
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            max_src_len=7,
            dropout=0,
            track_grad_square=True,
        )

        bs, src_len, tgt_len = 3, 5, 7
        x_enc = torch.randint(0, src_vocab_size, size=(bs, src_len))
        x_dec = torch.randint(0, tgt_vocab_size, size=(bs, tgt_len))

        dec_inp = x_dec[:, :-1].contiguous()
        labels = x_dec[:, 1:].contiguous()

        _ = model(input_ids=x_enc, decoder_input_ids=dec_inp, labels=labels)

        model.register_weight_consolidation_buffer()

        self.assertIsNotNone(model.omega)
        omega = model.omega[
            "encoder.encoder.layer.0.attention.self.value.weight"]
        self.assertFalse(torch.all(torch.isinf(omega)))

        for name, omega in model.omega.items():
            self.assertTrue(torch.all(omega >= 0))

        state_dict = model.state_dict()
        omega_buffer = state_dict[
            "omega_encoder_encoder_layer_0_attention_self_value_weight"]
        self.assertTrue(torch.all(omega_buffer != 0))
    def test_update_grad_squared(self):
        src_vocab_size = 23
        tgt_vocab_size = 17

        model = EncoderDecoderWPointerModel.from_parameters(
            layers=1,
            hidden=32,
            heads=2,
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            max_src_len=7,
            dropout=0,
            move_norm=100,
            track_grad_square=True,
        )

        bs, src_len, tgt_len = 3, 5, 7
        x_enc = torch.randint(0, src_vocab_size, size=(bs, src_len))
        x_dec = torch.randint(0, tgt_vocab_size, size=(bs, tgt_len))

        dec_inp = x_dec[:, :-1].contiguous()
        labels = x_dec[:, 1:].contiguous()

        out = model(input_ids=x_enc, decoder_input_ids=dec_inp, labels=labels)

        loss = out[0]

        grad_squared = deepcopy(model.grad_squared)
        assert grad_squared is not None

        model._update_grad_squared(loss)

        n_changed = 0

        for name, grad2 in grad_squared.items():
            if not torch.allclose(model.grad_squared[name], grad2):
                n_changed += 1

        self.assertGreater(n_changed, 40)
    # NOTE: this dataset object does not have labels
    dataset: PointerDataset = make_test_dataset(args.data,
                                                schema_tokenizer,
                                                max_len=args.src_max_len)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=Seq2SeqDataCollator(
            pad_id=text_tokenizer.pad_token_id).collate_batch,
        num_workers=8,
    )

    logger.info(f"Maximum source text length {dataset.get_max_len()[0]}")

    model = EncoderDecoderWPointerModel.from_pretrained(args.model).to(
        args.device)
    model.eval()

    predictions_ids, predictions_str = cli_utils.iterative_prediction(
        model=model,
        dataloader=dataloader,
        schema_tokenizer=schema_tokenizer,
        max_len=args.tgt_max_len,
        num_beams=args.num_beams,
        device=args.device,
    )

    # predictions should be postprocessed for evaluation (reproduce TOP format tokenization)
    predictions_str = [
        schema_tokenizer.postprocess(p) for p in predictions_str
    ]
    def test_get_weight_consolidation(self):
        src_vocab_size = 23
        tgt_vocab_size = 17

        model = EncoderDecoderWPointerModel.from_parameters(
            layers=1,
            hidden=32,
            heads=2,
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            max_src_len=7,
            dropout=0,
            track_grad_square=True,
        )

        bs, src_len, tgt_len = 3, 5, 7
        x_enc = torch.randint(0, src_vocab_size, size=(bs, src_len))
        x_dec = torch.randint(0, tgt_vocab_size, size=(bs, tgt_len))

        dec_inp = x_dec[:, :-1].contiguous()
        labels = x_dec[:, 1:].contiguous()

        optimizer = torch.optim.Adam(model.parameters())

        for _ in range(3):
            optimizer.zero_grad()

            out = model(input_ids=x_enc,
                        decoder_input_ids=dec_inp,
                        labels=labels)
            loss = out[0]

            loss.backward()
            optimizer.step()

        model.register_weight_consolidation_buffer()

        state_dict = model.state_dict()

        new_model = EncoderDecoderWPointerModel.from_parameters(
            layers=1,
            hidden=32,
            heads=2,
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            max_src_len=7,
            dropout=0,
            weight_consolidation=100,
        )

        new_model.load_state_dict(state_dict)

        reg = new_model._get_weight_consolidation()

        # before finetuning reg is 0
        self.assertIsNotNone(reg)
        self.assertFalse(torch.isinf(reg))
        self.assertEqual(reg, 0)

        optimizer = torch.optim.Adam(new_model.parameters())

        for _ in range(3):
            optimizer.zero_grad()

            out = new_model(input_ids=x_enc,
                            decoder_input_ids=dec_inp,
                            labels=labels)
            loss = out[0]

            loss.backward()
            optimizer.step()

        reg = new_model._get_weight_consolidation()
        self.assertGreater(reg, 0)
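
    # For reference (a hedged sketch, not necessarily the exact formula used by
    # _get_weight_consolidation): the consolidation term is the usual
    # elastic-weight-consolidation penalty, roughly
    #
    #   reg = config.weight_consolidation * sum(
    #       torch.sum(omega * (p - p_old) ** 2)
    #       for omega, p, p_old in per_parameter_buffers)
    #
    # where omega accumulates squared gradients, p_old is the pre-finetuning copy
    # of the parameter, and per_parameter_buffers is a placeholder name.
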
    def test_move_norm_update(self):
        """Test that move norm affects optimization"""

        src_vocab_size = 23
        tgt_vocab_size = 17

        model = EncoderDecoderWPointerModel.from_parameters(
            layers=1,
            hidden=32,
            heads=2,
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            max_src_len=7,
            dropout=0,
            move_norm=100,
        )

        model_copy = deepcopy(model)
        model_copy.config.move_norm = None
        del model_copy.initial_params

        model_copy2 = deepcopy(model)
        model_copy2.config.move_norm = None
        del model_copy2.initial_params

        for (n1, p1), (n2, p2) in zip(model.named_parameters(),
                                      model_copy.named_parameters()):
            assert n1 == n2
            assert torch.allclose(p1, p2)

        # random data for comparing optimization with and without the move norm
        bs, src_len, tgt_len = 3, 5, 7
        x_enc = torch.randint(0, src_vocab_size, size=(bs, src_len))
        x_dec = torch.randint(0, tgt_vocab_size, size=(bs, tgt_len))

        dec_inp = x_dec[:, :-1].contiguous()
        labels = x_dec[:, 1:].contiguous()

        losses = []
        for _model in [model, model_copy, model_copy2]:
            for _ in range(2):
                # at the first update move_norm = 0 as model == initial
                optimizer = torch.optim.SGD(_model.parameters(), 1e-3)

                out = _model(input_ids=x_enc,
                             decoder_input_ids=dec_inp,
                             labels=labels)

                loss = out[0]
                loss.backward()

                optimizer.step()

            losses.append(loss.detach())

        self.assertTrue(torch.allclose(losses[1], losses[2]),
                        msg="test is not deterministic")
        self.assertFalse(torch.allclose(losses[0], losses[1]))

        for (n1, p1), (n2, p2), (n3, p3) in zip(
                model.named_parameters(),
                model_copy.named_parameters(),
                model_copy2.named_parameters()):
            assert n1 == n2 == n3
            if "pooler" in n1:
                # we do not use pooler weights
                continue

            self.assertTrue(torch.allclose(p2, p3),
                            msg="test is not deterministic")
            self.assertFalse(torch.allclose(p1, p2), msg=n1)
    def test_freeze(self):
        src_vocab_size = 23
        tgt_vocab_size = 17

        model = EncoderDecoderWPointerModel.from_parameters(
            layers=1,
            hidden=32,
            heads=2,
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            max_src_len=7,
            dropout=0,
        )

        # check that all parameters are trainable
        for name, param in model.named_parameters():
            self.assertTrue(param.requires_grad, msg=name)

        model.freeze_encoder()
        model.freeze_decoder()
        model.freeze_head()

        # check that all parameters are frozen
        for name, param in model.named_parameters():
            self.assertFalse(param.requires_grad, msg=name)

        model.freeze_encoder(freeze=False)
        model.freeze_decoder(freeze=False)
        model.freeze_head(freeze=False)

        # check that all parameters are trainable again
        for name, param in model.named_parameters():
            self.assertTrue(param.requires_grad, msg=name)

        # check that optimizer state accumulated before freezing
        # (e.g. Adam moments) does not interfere with the freezing
        bs, src_len, tgt_len = 3, 5, 7
        x_enc = torch.randint(0, src_vocab_size, size=(bs, src_len))
        x_dec = torch.randint(0, tgt_vocab_size, size=(bs, tgt_len))

        dec_inp = x_dec[:, :-1].contiguous()
        labels = x_dec[:, 1:].contiguous()

        for opt_class in [torch.optim.SGD, torch.optim.Adam]:
            with self.subTest(repr(opt_class)):
                optimizer = opt_class(model.parameters(), lr=1e-3)

                for _ in range(5):
                    out = model(input_ids=x_enc,
                                decoder_input_ids=dec_inp,
                                labels=labels)

                    loss = out[0]
                    loss.backward()

                    optimizer.step()
                    optimizer.zero_grad()

                model.freeze_encoder(freeze=True)
                model.freeze_decoder(freeze=True)

                model_copy = deepcopy(model)

                # do multiple optimizer updates to ensure that Adam's accumulated
                # moments do not modify the frozen parameters
                for _ in range(5):
                    optimizer.zero_grad()

                    out = model(input_ids=x_enc,
                                decoder_input_ids=dec_inp,
                                labels=labels)

                    loss = out[0]
                    loss.backward()

                    optimizer.step()
                    optimizer.zero_grad()

                for (n1, p1), (n2, p2) in zip(
                        model.encoder.named_parameters(),
                        model_copy.encoder.named_parameters()):
                    assert n1 == n2
                    self.assertFalse(p1.requires_grad)
                    self.assertTrue(torch.allclose(p1, p2),
                                    msg=f"Optimizer state changed {n1}")

                for (n1, p1), (n2, p2) in zip(
                        model.decoder.named_parameters(),
                        model_copy.decoder.named_parameters()):
                    assert n1 == n2
                    self.assertFalse(p1.requires_grad)
                    self.assertTrue(torch.allclose(p1, p2),
                                    msg=f"Optimizer state changed {n1}")