import torch
from torch_struct import DependencyCRF
# Assumed imports: AdamW / WarmupLinearSchedule as exposed by pytorch_transformers
from pytorch_transformers import AdamW, WarmupLinearSchedule


def train(train_iter, val_iter, model):
    opt = AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)
    model.train()
    losses = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.head
        batch, _ = label.shape

        # Model forward pass: arc scores of shape [batch, N, N]
        final = model(words.cuda(), mapper)

        # Zero out arc scores beyond each sentence's true length
        for b in range(batch):
            final[b, lengths[b] - 1:, :] = 0
            final[b, :, lengths[b] - 1:] = 0

        # Skip batches whose sentences do not fit the score matrix
        if not lengths.max() <= final.shape[1] + 1:
            print("fail")
            continue

        dist = DependencyCRF(final, lengths=lengths)
        labels = dist.struct.to_parts(label, lengths=lengths).type_as(final)
        log_prob = dist.log_prob(labels)

        # Maximize the tree log-likelihood, i.e. minimize its negation
        loss = log_prob.sum()
        (-loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        scheduler.step()

        losses.append(loss.detach())
        if i % 50 == 1:
            print(-torch.tensor(losses).mean(), words.shape)
            losses = []
        if i % 600 == 500:
            validate(val_iter)
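For reference, here is a minimal, self-contained sketch of the torch-struct calls the loop above relies on. The toy arc scores and gold heads are invented for illustration; only the tensor shapes and the `to_parts` / `log_prob` pattern mirror the listing above.

import torch
from torch_struct import DependencyCRF

batch, N = 2, 5
arc_scores = torch.randn(batch, N, N)     # head-dependent scores; root attachments on the diagonal
lengths = torch.tensor([5, 3])            # true sentence lengths

dist = DependencyCRF(arc_scores, lengths=lengths)
heads = torch.tensor([[2, 0, 2, 3, 4],    # toy gold head per token, 0 = root
                      [2, 0, 2, 0, 0]])   # second sentence padded past length 3
parts = dist.struct.to_parts(heads, lengths=lengths).type_as(arc_scores)
nll = -dist.log_prob(parts).sum()         # the quantity minimized in train()
best_arcs = dist.argmax                   # MAP parse as a [batch, N, N] arc matrix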
    # End of the preceding branch of the training step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    scheduler.step()
else:
    # Drop sub-token / padding positions from the tags, then re-pad to a rectangular tensor
    b_tags = [tag[mask] for mask, tag in zip(b_label_masks, b_tags)]
    b_tags = pad_sequence(b_tags, batch_first=True, padding_value=0)

    loss_main, logits, labels, final = model(
        b_input_ids, b_tags, labels=b_labels, label_masks=b_label_masks)

    # Auxiliary dependency loss, skipped when a sentence exceeds the score matrix
    if not lengths.max() <= final.shape[1]:
        dep_loss = 0
    else:
        dist = DependencyCRF(final, lengths=lengths)
        dep_labels = dist.struct.to_parts(b_tags, lengths=lengths).type_as(final)  # [batch, N, N]
        log_prob = dist.log_prob(dep_labels)
        dep_loss = log_prob.mean()  # or .sum()

    # When the tree log-likelihood is negative (the usual case), add its scaled
    # magnitude to the main loss
    if dep_loss < 0:
        loss = loss_main - dep_loss / dep_loss_factor
    else:
        loss = loss_main

    loss.backward()
    tr_loss += loss.item()
    nb_tr_steps += 1
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
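The fragment above assumes `optimizer`, `scheduler`, `max_grad_norm`, `dep_loss_factor`, `tr_loss`, `nb_tr_steps`, and `pad_sequence` were created earlier in the script. A minimal sketch of that setup, assuming the same pytorch_transformers-era schedule API as the first listing; the hyperparameter values and `num_train_steps` are placeholders rather than values from the source:

from torch.nn.utils.rnn import pad_sequence
from pytorch_transformers import AdamW, WarmupLinearSchedule

max_grad_norm = 1.0        # placeholder, not from the source
dep_loss_factor = 1.0      # placeholder weight on the auxiliary dependency loss
tr_loss, nb_tr_steps = 0.0, 0

optimizer = AdamW(model.parameters(), lr=3e-5, eps=1e-8)   # lr is a placeholder
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=0,
                                 t_total=num_train_steps)  # num_train_steps assumed to be defined elsewhere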