def get_optimizer_and_schedule(args, model: EncoderDecoderModel):
    # Use different learning rates for the pretrained parameters and the newly initialized cross-attention parameters.
    if args.ngpu > 1:
        model = model.module
    init_params_id = []
    for layer in model.decoder.transformer.h:
        init_params_id.extend(list(map(id, layer.crossattention.parameters())))
        init_params_id.extend(list(map(id, layer.ln_cross_attn.parameters())))
    pretrained_params = filter(
        lambda p: id(p) not in init_params_id, model.parameters()
    )
    initialized_params = filter(lambda p: id(p) in init_params_id, model.parameters())
    params_setting = [
        {"params": initialized_params},
        {"params": pretrained_params, "lr": args.finetune_lr},
    ]
    optimizer = optim.Adam(params_setting, lr=args.lr)
    schedule = get_cosine_schedule_with_warmup(
        optimizer,
        num_training_steps=args.num_training_steps,
        num_warmup_steps=args.num_warmup_steps,
    )
    return optimizer, schedule
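
A minimal usage sketch for the helper above, assuming an `args` namespace with the fields it reads (`ngpu`, `lr`, `finetune_lr`, `num_training_steps`, `num_warmup_steps`) and a hypothetical `train_dataloader` whose batches include `labels`:

optimizer, schedule = get_optimizer_and_schedule(args, model)
model.train()
for batch in train_dataloader:
    loss = model(**batch).loss  # EncoderDecoderModel returns a loss when labels are provided
    loss.backward()
    optimizer.step()
    schedule.step()  # advance the cosine warmup schedule once per optimizer step
    optimizer.zero_grad()
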
    def create_and_check_encoder_decoder_shared_weights(
            self, config, input_ids, attention_mask, encoder_hidden_states,
            decoder_config, decoder_input_ids, decoder_attention_mask, labels,
            **kwargs):
        torch.manual_seed(0)
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        model = EncoderDecoderModel(encoder=encoder_model,
                                    decoder=decoder_model)
        model.to(torch_device)
        model.eval()
        # loading the state dict copies the weights but does not tie them
        decoder_state_dict = model.decoder._modules[
            model.decoder.base_model_prefix].state_dict()
        model.encoder.load_state_dict(decoder_state_dict, strict=False)

        torch.manual_seed(0)
        tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        config = EncoderDecoderConfig.from_encoder_decoder_configs(
            tied_encoder_model.config,
            tied_decoder_model.config,
            tie_encoder_decoder=True)
        tied_model = EncoderDecoderModel(encoder=tied_encoder_model,
                                         decoder=tied_decoder_model,
                                         config=config)
        tied_model.to(torch_device)
        tied_model.eval()

        model_result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        tied_model_result = tied_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        # check that the tied model has fewer parameters
        self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                        sum(p.numel() for p in model.parameters()))
        random_slice_idx = ids_tensor((1, ), model_result[0].shape[-1]).item()

        # check that outputs are equal
        self.assertTrue(
            torch.allclose(model_result[0][0, :, random_slice_idx],
                           tied_model_result[0][0, :, random_slice_idx],
                           atol=1e-4))

        # check that outputs after saving and loading are equal
        with tempfile.TemporaryDirectory() as tmpdirname:
            tied_model.save_pretrained(tmpdirname)
            tied_model = EncoderDecoderModel.from_pretrained(tmpdirname)
            tied_model.to(torch_device)
            tied_model.eval()

            # check that the tied model has fewer parameters
            self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                            sum(p.numel() for p in model.parameters()))
            random_slice_idx = ids_tensor((1, ),
                                          model_result[0].shape[-1]).item()

            tied_model_result = tied_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )

            # check that outputs are equal
            self.assertTrue(
                torch.allclose(model_result[0][0, :, random_slice_idx],
                               tied_model_result[0][0, :, random_slice_idx],
                               atol=1e-4))
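
Outside the test harness, the same weight tying can be requested directly through the config; the sketch below assumes bert-base-uncased configs and randomly initialized BERT modules from transformers.

from transformers import (BertConfig, BertLMHeadModel, BertModel,
                          EncoderDecoderConfig, EncoderDecoderModel)

enc_cfg = BertConfig.from_pretrained("bert-base-uncased")
dec_cfg = BertConfig.from_pretrained("bert-base-uncased",
                                     is_decoder=True,
                                     add_cross_attention=True)
tied_cfg = EncoderDecoderConfig.from_encoder_decoder_configs(
    enc_cfg, dec_cfg, tie_encoder_decoder=True)
tied_model = EncoderDecoderModel(encoder=BertModel(enc_cfg),
                                 decoder=BertLMHeadModel(dec_cfg),
                                 config=tied_cfg)
# With tied weights, the encoder and the decoder's base model share parameters,
# so the total parameter count is lower than for an untied model of the same size.
print(sum(p.numel() for p in tied_model.parameters()))
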
Example #3
class BERT2BERTTrainer(pl.LightningModule):
    def __init__(self, lr, **args):
        super(BERT2BERTTrainer, self).__init__()
        self.save_hyperparameters()
        encoder = BertGenerationEncoder.from_pretrained(
            "ckiplab/bert-base-chinese",
            bos_token_id=101,
            eos_token_id=102,
            # force_download=True
        )
        decoder = BertGenerationDecoder.from_pretrained(
            "ckiplab/bert-base-chinese",
            add_cross_attention=True,
            is_decoder=True,
            bos_token_id=101,
            eos_token_id=102)

        self.bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        if args['with_keywords_loss']:
            self.loss_fct2 = KeywordsLoss(alpha=args['keywords_loss_alpha'],
                                          loss_fct=args['keywords_loss_fct'])

    def generate(self, inputs_ids, attention_mask=None, **kwargs):
        inputs_ids = inputs_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        with torch.no_grad():
            return self.bert2bert.generate(input_ids=inputs_ids,
                                           attention_mask=attention_mask,
                                           bos_token_id=101,
                                           min_length=100,
                                           eos_token_id=102,
                                           pad_token_id=0,
                                           **kwargs).detach().cpu().tolist()

    def forward(self, inputs):
        with torch.no_grad():
            return self.bert2bert(**inputs)

    def training_step(self, inputs, batch_idx):
        title, body = inputs
        title['input_ids'] = title['input_ids'].squeeze(1)
        title['attention_mask'] = title['attention_mask'].squeeze(1)
        body['input_ids'] = body['input_ids'].squeeze(1)
        body['attention_mask'] = body['attention_mask'].squeeze(1)
        ret = self.bert2bert(input_ids=title['input_ids'],
                             attention_mask=title['attention_mask'],
                             decoder_input_ids=body['input_ids'],
                             decoder_attention_mask=body['attention_mask'],
                             labels=body['input_ids'])
        loss2 = (self.loss_fct2(ret.logits, title['input_ids'])
                 if self.hparams['with_keywords_loss'] else 0.)
        self.log('keyword_loss', loss2, prog_bar=True)
        self.log('clm_loss', ret.loss, prog_bar=True)

        return {'loss': ret.loss + loss2, 'keyword_loss': loss2}

    def training_epoch_end(self, outputs):
        mean_loss = torch.stack([x['loss']
                                 for x in outputs]).reshape(-1).mean()
        self.log('mean_loss', mean_loss)

    def configure_optimizers(self):
        opt = optim.AdamW(self.bert2bert.parameters(), lr=self.hparams['lr'])
        return opt

    @staticmethod
    def add_parser_args(parser):
        # parser.add_argument('--lr', type=float)
        parser.add_argument('--with_keywords_loss', action='store_true')
        parser.add_argument('--keywords_loss_alpha',
                            type=float,
                            default=0.7,
                            help='float > 0.5')
        parser.add_argument('--keywords_loss_fct',
                            type=str,
                            default='kldiv',
                            help='kldiv or mse')

        return parser
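
A minimal sketch of running the Lightning module above; `make_dataloader()` is a hypothetical helper that yields the (title, body) batches of tokenizer outputs that training_step() expects.

import argparse

import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=5e-5)
parser = BERT2BERTTrainer.add_parser_args(parser)
args = parser.parse_args()

model = BERT2BERTTrainer(**vars(args))
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, make_dataloader())
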
Example #4
    pad_token_id = decoder_tokenizer.vocab["[PAD]"],
    cls_token_id = decoder_tokenizer.vocab["[CLS]"],
    mask_token_id = decoder_tokenizer.vocab["[MASK]"],
    bos_token_id = decoder_tokenizer.vocab["[BOS]"],
    eos_token_id = decoder_tokenizer.vocab["[EOS]"],
    )
# Initialize a brand new bert-based decoder.
decoder = BertGenerationDecoder(config=decoder_config)

# Set up the encoder-decoder model.
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
bert2bert.config.decoder_start_token_id=decoder_tokenizer.vocab["[CLS]"]
bert2bert.config.pad_token_id=decoder_tokenizer.vocab["[PAD]"]

# Basic training setup.
optimizer = torch.optim.Adam(bert2bert.parameters(), lr=0.000001)
bert2bert.cuda()

for epoch in range(30):
    print("*"*50, "Epoch", epoch, "*"*50)
    for batch in tqdm(sierra_dl):
        # tokenize commands and goals.
        inputs = encoder_tokenizer(batch["command"],
                                   add_special_tokens=True,
                                   return_tensors="pt",
                                   padding=True,
                                   truncation=True)
        labels = decoder_tokenizer(batch["symbolic_plan_processed"],
                                   return_tensors="pt",
                                   padding=True,
                                   max_length=sierra_ds.max_plan_length,
                                   truncation=True,
                                   add_special_tokens=True)

        # Move to GPU.
        for key, item in inputs.items():
            if isinstance(item, torch.Tensor):
                inputs[key] = item.cuda()
        for key, item in labels.items():
Example #5
def train_model(epochs=10,
                num_gradients_accumulation=4,
                batch_size=4,
                gpu_id=0,
                lr=1e-5,
                load_dir='/content/BERT checkpoints'):
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased',
                                              is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)

    model = model.to(device)

    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD TRAIN DATA------------------
    train_data = torch.load("/content/train_data.pth")
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  shuffle=True,
                                  batch_size=batch_size)
    val_data = torch.load("/content/validate_data.pth")
    val_dataset = TensorDataset(*val_data)
    val_dataloader = DataLoader(dataset=val_dataset,
                                shuffle=True,
                                batch_size=batch_size)
    # ------------------------END LOAD TRAIN DATA--------------

    # ------------------------SET OPTIMIZER-------------------
    num_train_optimization_steps = len(
        train_dataset) * epochs // batch_size // num_gradients_accumulation

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01,
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=lr,
        weight_decay=0.01,  # default only; the per-group 'weight_decay' values above take precedence
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_train_optimization_steps // 10,
        num_training_steps=num_train_optimization_steps)

    # ------------------------START TRAINING-------------------
    update_count = 0
    max_grad_norm = 1.0  # assumed value: not defined in this excerpt but required by clip_grad_norm_ below

    start = time.time()
    print('start training....')
    for epoch in range(epochs):
        # ------------------------training------------------------
        model.train()
        losses = 0
        times = 0

        print('\n' + '-' * 20 + f'epoch {epoch}' + '-' * 20)
        for batch in tqdm(train_dataloader):
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
            logits = model(input_ids=encoder_input,
                           attention_mask=mask_encoder_input,
                           decoder_input_ids=decoder_input,
                           decoder_attention_mask=mask_decoder_input)

            out = logits[0][:, :-1].contiguous()
            target = decoder_input[:, 1:].contiguous()
            target_mask = mask_decoder_input[:, 1:].contiguous()
            loss = util.sequence_cross_entropy_with_logits(out,
                                                           target,
                                                           target_mask,
                                                           average="token")
            loss.backward()

            losses += loss.item()
            times += 1

            update_count += 1

            if update_count % num_gradients_accumulation == num_gradients_accumulation - 1:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        end = time.time()
        print(f'time: {(end - start)}')
        print(f'loss: {losses / times}')
        start = end

        # ------------------------validate------------------------
        model.eval()

        perplexity = 0
        batch_count = 0
        print('\nstart calculate the perplexity....')

        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                batch = [item.to(device) for item in batch]

                encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
                logits = model(input_ids=encoder_input,
                               attention_mask=mask_encoder_input,
                               decoder_input_ids=decoder_input,
                               decoder_attention_mask=mask_decoder_input)

                out = logits[0][:, :-1].contiguous()
                target = decoder_input[:, 1:].contiguous()
                target_mask = mask_decoder_input[:, 1:].contiguous()
                # print(out.shape,target.shape,target_mask.shape)
                loss = util.sequence_cross_entropy_with_logits(out,
                                                               target,
                                                               target_mask,
                                                               average="token")
                perplexity += np.exp(loss.item())
                batch_count += 1

        print(f'\nvalidate perplexity: {perplexity / batch_count}')

        torch.save(
            model.state_dict(),
            os.path.join(os.path.abspath('.'), load_dir,
                         "model-" + str(epoch) + ".pth"))
Example #6
decoder = BertForMaskedLM(config=decoder_config)

# Define encoder decoder model
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model.to(device)


def count_parameters(mdl):
    return sum(p.numel() for p in mdl.parameters() if p.requires_grad)


print(f'The encoder has {count_parameters(encoder):,} trainable parameters')
print(f'The decoder has {count_parameters(decoder):,} trainable parameters')
print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters(), lr=modelparams['lr'])
criterion = nn.NLLLoss(ignore_index=de_tokenizer.pad_token_id)

num_train_batches = len(train_dataloader)
num_valid_batches = len(valid_dataloader)


def compute_loss(predictions, targets):
    """Compute our custom loss"""
    predictions = predictions[:, :-1, :].contiguous()
    targets = targets[:, 1:]

    rearranged_output = predictions.view(
        predictions.shape[0] * predictions.shape[1], -1)
    rearranged_target = targets.contiguous().view(-1)