Example No. 1
import torch
from torch_struct import DependencyCRF
# AdamW and WarmupLinearSchedule are assumed to come from the older
# pytorch-transformers package (newer transformers versions replace
# WarmupLinearSchedule with get_linear_schedule_with_warmup).
from pytorch_transformers import AdamW, WarmupLinearSchedule


def train(train_iter, val_iter, model):
    # AdamW optimizer with a linear warmup/decay learning-rate schedule
    opt = AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)
    model.train()
    losses = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
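        # Batch fields: token ids and a token-to-word mapper, plus gold head indices with sentence lengths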
        words, mapper, _ = ex.word
        label, lengths = ex.head
        batch, _ = label.shape

        # Forward pass: pairwise arc scores, one score matrix per sentence
        final = model(words.cuda(), mapper)
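        # Zero out arc scores beyond each sentence's true length (mask padding)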
        for b in range(batch):
            final[b, lengths[b]-1:, :] = 0
            final[b, :, lengths[b]-1:] = 0

        # Skip batches where the longest sentence does not fit in the score tensor
        if lengths.max() > final.shape[1] + 1:
            print("fail")
            continue
        dist = DependencyCRF(final, lengths=lengths)

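        # Convert gold head indices into the CRF's parts (arc) representation and score them under the model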
        labels = dist.struct.to_parts(label, lengths=lengths).type_as(final)
        log_prob = dist.log_prob(labels)

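        # Maximize the summed log-likelihood by minimizing its negation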
        loss = log_prob.sum()
        (-loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        opt.step()
        scheduler.step()
        losses.append(loss.detach())
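        # Periodically report the mean negative log-likelihood, and run validation less often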
        if i % 50 == 1:
            print(-torch.tensor(losses).mean(), words.shape)
            losses = []
        if i % 600 == 500:
            validate(val_iter)
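
The training loop above calls a validate helper that is not shown on this page. A minimal sketch of what it might look like, assuming the same batch fields as in train and that model is reachable from the enclosing scope; it uses torch-struct's DependencyCRF.argmax to obtain the MAP parse in the same parts layout as to_parts, and reports unlabeled attachment accuracy:

def validate(val_iter):
    # Hypothetical helper (not part of the original example): reports unlabeled
    # attachment accuracy on the validation set. Assumes `model` is visible in
    # the enclosing scope, matching the call `validate(val_iter)` above.
    model.eval()
    correct, total = 0.0, 0.0
    for ex in val_iter:
        words, mapper, _ = ex.word
        label, lengths = ex.head
        with torch.no_grad():
            final = model(words.cuda(), mapper)
        if lengths.max() > final.shape[1] + 1:
            continue
        dist = DependencyCRF(final, lengths=lengths)
        gold = dist.struct.to_parts(label, lengths=lengths).type_as(final)
        pred = dist.argmax  # MAP arcs, same parts layout as `gold`
        correct += ((pred > 0.5) & (gold > 0.5)).sum().item()
        total += gold.sum().item()
    print("validation UAS:", correct / max(total, 1.0))
    model.train()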
Example No. 2
            # Gradient clipping on the first branch; the call is reconstructed here
            # because the published snippet begins mid-statement
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=max_grad_norm)
            optimizer.step()
            scheduler.step()

        else:
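            # Keep only tags at positions selected by the label masks, then re-pad to a rectangular batch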
            b_tags = [tag[mask] for mask, tag in zip(b_label_masks, b_tags)]
            b_tags = pad_sequence(b_tags, batch_first=True, padding_value=0)

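            # Forward pass: main tagging loss plus arc scores ("final") for the dependency CRF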
            loss_main, logits, labels, final = model(b_input_ids, b_tags, labels=b_labels, label_masks=b_label_masks)

            # Auxiliary dependency loss, computed only when every sentence fits in the score tensor
            if lengths.max() > final.shape[1]:
                dep_loss = 0
            else:
                dist = DependencyCRF(final, lengths=lengths)
                dep_labels = dist.struct.to_parts(b_tags, lengths=lengths).type_as(final)  # [BATCH_SIZE, lengths, lengths]
                log_prob = dist.log_prob(dep_labels)

                dep_loss = log_prob.mean()  # alternatively: log_prob.sum()

            # log_prob is a log-likelihood (<= 0), so subtracting it adds a scaled
            # dependency NLL term to the main loss; dep_loss_factor controls its weight
            if dep_loss < 0:
                loss = loss_main - dep_loss / dep_loss_factor
            else:
                loss = loss_main

            loss.backward()
            tr_loss += loss.item()
            nb_tr_steps += 1
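            # Clip gradients before the optimizer and learning-rate scheduler steps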
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()