# Imports needed by the training routines below (restored here; the original files
# import these at module level -- AdamW and the scheduler are assumed to come from
# Hugging Face transformers).
import random
import time

import numpy as np
import torch
from torch.nn import CrossEntropyLoss, MarginRankingLoss
from transformers import AdamW, get_linear_schedule_with_warmup


def train(model, tokenizer, coarse_train_dataloader, coarse_validation_dataloader, doc_start_ind, all_labels,
          device, pad_token_dict):
    def calculate_loss(lm_logits, b_labels, b_input_mask, doc_start_ind):
        loss_fct = CrossEntropyLoss()
        batch_size = lm_logits.shape[0]
        logits_collected = []
        labels_collected = []
        for b in range(batch_size):
            logits_ind = lm_logits[b, :, :]  # seq_len x |V|
            labels_ind = b_labels[b, :]  # seq_len
            mask = b_input_mask[b, :] > 0
            maski = mask.unsqueeze(-1).expand_as(logits_ind)
            # unpad_seq_len x |V|
            logits_pad_removed = torch.masked_select(logits_ind, maski).view(-1, logits_ind.size(-1))
            labels_pad_removed = torch.masked_select(labels_ind, mask)  # unpad_seq_len
            shift_logits = logits_pad_removed[doc_start_ind - 1:-1, :].contiguous()
            shift_labels = labels_pad_removed[doc_start_ind:].contiguous()
            # Flatten the tokens
            logits_collected.append(shift_logits.view(-1, shift_logits.size(-1)))
            labels_collected.append(shift_labels.view(-1))

        logits_collected = torch.cat(logits_collected, dim=0)
        labels_collected = torch.cat(labels_collected, dim=0)
        loss = loss_fct(logits_collected, labels_collected)
        return loss

    optimizer = AdamW(
        model.parameters(),
        lr=5e-4,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
        eps=1e-8  # args.adam_epsilon - default is 1e-8.
    )

    sample_every = 100
    warmup_steps = 1e2
    epochs = 5

    total_steps = len(coarse_train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)

    seed_val = 42
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    training_stats = []
    total_t0 = time.time()

    for epoch_i in range(0, epochs):
        print("", flush=True)
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs), flush=True)
        print('Training...', flush=True)

        t0 = time.time()
        total_train_loss = 0
        model.train()

        for step, batch in enumerate(coarse_train_dataloader):
            if step % sample_every == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.  Elapsed: {:}.'.format(step, len(coarse_train_dataloader),
                                                                          elapsed), flush=True)
                # Periodically sample from the model to monitor generation quality.
                model.eval()
                lbl = random.choice(all_labels)
                temp_list = ["<|labelpad|>"] * pad_token_dict[lbl]
                if len(temp_list) > 0:
                    label_str = " ".join(lbl.split("_")) + " " + " ".join(temp_list)
                else:
                    label_str = " ".join(lbl.split("_"))
                text = tokenizer.bos_token + " " + label_str + " <|labelsep|> "
                sample_outputs = model.generate(input_ids=tokenizer.encode(text, return_tensors='pt').to(device),
                                                do_sample=True,
                                                top_k=50,
                                                max_length=200,
                                                top_p=0.95,
                                                num_return_sequences=1)
                for i, sample_output in enumerate(sample_outputs):
                    print("{}: {}".format(i, tokenizer.decode(sample_output)), flush=True)
                model.train()

            b_input_ids = batch[0].to(device)
            b_labels = batch[0].to(device)
            b_input_mask = batch[1].to(device)

            model.zero_grad()
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)

            loss = calculate_loss(outputs[1], b_labels, b_input_mask, doc_start_ind)
            # loss = outputs[0]
            total_train_loss += loss.item()

            loss.backward()
            optimizer.step()
            scheduler.step()

        # **********************************
        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / len(coarse_train_dataloader)

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("", flush=True)
        print("  Average training loss: {0:.2f}".format(avg_train_loss), flush=True)
        print("  Training epoch took: {:}".format(training_time), flush=True)

        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.

        print("", flush=True)
        print("Running Validation...", flush=True)

        t0 = time.time()
        model.eval()
        total_eval_loss = 0
        nb_eval_steps = 0

        # Evaluate data for one epoch
        for batch in coarse_validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_labels = batch[0].to(device)
            b_input_mask = batch[1].to(device)

            with torch.no_grad():
                outputs = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask,
                                labels=b_labels)
                # Accumulate the validation loss.
                loss = calculate_loss(outputs[1], b_labels, b_input_mask, doc_start_ind)
                # loss = outputs[0]
            total_eval_loss += loss.item()

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(coarse_validation_dataloader)

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)

        print("  Validation Loss: {0:.2f}".format(avg_val_loss), flush=True)
        print("  Validation took: {:}".format(validation_time), flush=True)

        # Record all statistics from this epoch.
        training_stats.append({
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        })

    print("", flush=True)
    print("Training complete!", flush=True)
    print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)), flush=True)
    return model
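
# NOTE: format_time(...) is called throughout this file but is defined elsewhere in the
# repository. A minimal sketch of the assumed helper (round the elapsed seconds and render
# them as an h:mm:ss string, matching the "(h:mm:ss)" log message above); the original
# implementation may differ in detail.
import datetime


def format_time(elapsed):
    """Format a duration given in seconds as an h:mm:ss string."""
    elapsed_rounded = int(round(elapsed))
    return str(datetime.timedelta(seconds=elapsed_rounded))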
def train(coarse_model, fine_model, coarse_tokenizer, fine_tokenizer, train_dataloader, validation_dataloader,
          label_to_exclusive_dataloader, doc_start_ind, index_to_label, device, secondary_device):
    def calculate_kl_div_loss(batch_fine_probs, batch_coarse_probs, batch_fine_input_masks, batch_coarse_input_masks,
                              batch_fine_input_ids, batch_coarse_input_ids, coarse_tokenizer, fine_tokenizer,
                              doc_start_ind):
        # Remove pad tokens and consider positions from doc_start_ind - 1 onwards.
        loss_fct = torch.nn.KLDivLoss(reduction="batchmean")
        batch_size = batch_fine_probs.shape[0]
        losses = []
        for b in range(batch_size):
            fine_logits_ind = batch_fine_probs[b, :, :]  # seq_len x |V|
            coarse_logits_ind = batch_coarse_probs[b, :, :]  # seq_len x |V|
            fine_mask = batch_fine_input_masks[b, :] > 0
            coarse_mask = batch_coarse_input_masks[b, :] > 0
            if not torch.all(fine_mask.eq(coarse_mask)):
                print("Fine sentence", fine_tokenizer.decode(batch_fine_input_ids[b, :]))
                print("Coarse sentence", coarse_tokenizer.decode(batch_coarse_input_ids[b, :]))
                raise Exception("Fine and coarse masks are not the same")
            fine_dec_sent = fine_tokenizer.decode(batch_fine_input_ids[b, :][doc_start_ind:])
            coarse_dec_sent = coarse_tokenizer.decode(batch_coarse_input_ids[b, :][doc_start_ind:])
            if fine_dec_sent != coarse_dec_sent:
                print("Fine sentence ", fine_tokenizer.decode(batch_fine_input_ids[b, :][doc_start_ind:]))
                print("Coarse sentence ", coarse_tokenizer.decode(batch_coarse_input_ids[b, :][doc_start_ind:]))
                raise Exception("Fine and coarse decoded sentences are not the same")
            fine_maski = fine_mask.unsqueeze(-1).expand_as(fine_logits_ind)
            coarse_maski = coarse_mask.unsqueeze(-1).expand_as(coarse_logits_ind)
            # unpad_seq_len x |V|
            fine_logits_pad_removed = torch.masked_select(fine_logits_ind, fine_maski).view(
                -1, fine_logits_ind.size(-1))
            coarse_logits_pad_removed = torch.masked_select(coarse_logits_ind, coarse_maski).view(
                -1, coarse_logits_ind.size(-1))

            shift_fine_logits = fine_logits_pad_removed[doc_start_ind - 1:-1, :].contiguous()
            shift_coarse_logits = coarse_logits_pad_removed[doc_start_ind - 1:-1, :].contiguous()

            # Per-document KL divergence between the fine mixture (log-space) and the coarse distribution.
            loss = loss_fct(shift_fine_logits, shift_coarse_logits).unsqueeze(0)
            losses.append(loss)

        # Return the mean of the per-document losses.
        losses = torch.cat(losses, dim=0)
        return losses.mean()

    def calculate_cross_entropy_loss(fine_model, label_to_exclusive_dataloader, doc_start_ind, device):
        loss_function = CrossEntropyLoss()
        b_labels_list = []
        b_input_ids_list = []
        b_input_mask_list = []
        scores_list = []
        # for l in label_to_exclusive_dataloader:
        selected_labs = random.sample(list(label_to_exclusive_dataloader.keys()), 1)
        for l in selected_labs:
            # print("Label", l)
            dataloader = label_to_exclusive_dataloader[l]
            it = 0
            for step, batch in enumerate(dataloader):
                # print("Step for exc", step, it)
                b_input_ids = batch[0].to(device)
                b_labels = batch[0].to(device)
                b_input_mask = batch[1].to(device)
                outputs = fine_model(b_input_ids,
                                     token_type_ids=None,
                                     attention_mask=b_input_mask,
                                     labels=b_labels)
                b_labels_list.append(b_labels)
                b_input_ids_list.append(b_input_ids)
                b_input_mask_list.append(b_input_mask)
                scores_list.append(outputs[1])
                # reporter = MemReporter()
                # reporter.report()
                it += 1
                if it == 1:
                    break

        b_labels_tensor = torch.cat(b_labels_list, dim=0)
        b_input_ids_tensor = torch.cat(b_input_ids_list, dim=0)
        b_input_mask_tensor = torch.cat(b_input_mask_list, dim=0)
        scores_tensor = torch.cat(scores_list, dim=0)
        assert b_labels_tensor.shape[0] == b_input_ids_tensor.shape[0] == b_input_mask_tensor.shape[0] == \
               scores_tensor.shape[0]

        batch_size = scores_tensor.shape[0]
        logits_collected = []
        labels_collected = []
        for b in range(batch_size):
            logits_ind = scores_tensor[b, :, :]  # seq_len x |V|
            labels_ind = b_labels_tensor[b, :]  # seq_len
            mask = b_input_mask_tensor[b, :] > 0
            maski = mask.unsqueeze(-1).expand_as(logits_ind)
            # unpad_seq_len x |V|
            logits_pad_removed = torch.masked_select(logits_ind, maski).view(-1, logits_ind.size(-1))
            labels_pad_removed = torch.masked_select(labels_ind, mask)  # unpad_seq_len
            shift_logits = logits_pad_removed[doc_start_ind - 1:-1, :].contiguous()
            shift_labels = labels_pad_removed[doc_start_ind:].contiguous()
            # Flatten the tokens
            logits_collected.append(shift_logits.view(-1, shift_logits.size(-1)))
            labels_collected.append(shift_labels.view(-1))

        logits_collected = torch.cat(logits_collected, dim=0)
        labels_collected = torch.cat(labels_collected, dim=0)
        loss = loss_function(logits_collected, labels_collected).to(device)
        return loss

    def calculate_loss(batch_fine_probs, batch_coarse_probs, batch_fine_input_masks, batch_coarse_input_masks,
                       batch_fine_input_ids, batch_coarse_input_ids, coarse_tokenizer, fine_tokenizer, fine_model,
                       label_to_exclusive_dataloader, doc_start_ind, device, lambda_1=5, is_val=False):
        kl_div_loss = calculate_kl_div_loss(batch_fine_probs, batch_coarse_probs, batch_fine_input_masks,
                                            batch_coarse_input_masks, batch_fine_input_ids, batch_coarse_input_ids,
                                            coarse_tokenizer, fine_tokenizer, doc_start_ind)
        # del batch_fine_probs
        # del batch_coarse_probs
        # del batch_fine_input_masks
        # del batch_coarse_input_masks
        # del batch_fine_input_ids
        # del batch_coarse_input_ids
        # torch.cuda.empty_cache()
        if not is_val:
            cross_ent_loss = calculate_cross_entropy_loss(fine_model, label_to_exclusive_dataloader, doc_start_ind,
                                                          device)
            print("KL-loss", kl_div_loss.item(), "CE-loss", cross_ent_loss.item())
        else:
            cross_ent_loss = 0
            print("KL-loss", kl_div_loss.item(), "CE-loss", cross_ent_loss)
        return (1 - lambda_1) * kl_div_loss + lambda_1 * cross_ent_loss

    def compute_lambda(step, max_steps):
        # Linearly anneal the cross-entropy weight from 1 to 0 over training.
        temp = 1 - step / max_steps
        if temp < 0:
            return 0
        else:
            return temp

    # epsilon = 1e-20  # Defined to avoid log probability getting undefined.
    fine_posterior = torch.nn.Parameter(torch.ones(len(index_to_label)).to(device))
    optimizer = AdamW(list(fine_model.parameters()) + [fine_posterior],
                      lr=5e-4,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
                      eps=1e-8  # args.adam_epsilon - default is 1e-8.
                      )

    sample_every = 100
    warmup_steps = 1e2
    epochs = 2

    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)

    seed_val = 42
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    training_stats = []
    total_t0 = time.time()
    coarse_model.eval()

    global_step = 0
    for epoch_i in range(0, epochs):
        print("", flush=True)
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs), flush=True)
        print('Training...', flush=True)

        t0 = time.time()
        total_train_loss = 0
        fine_model.train()

        for step, batch in enumerate(train_dataloader):
            # batch contains -> coarse_input_ids, coarse_attention_masks, fine_input_ids, fine_attention_masks
            if step % sample_every == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.  Elapsed: {:}.'.format(step, len(train_dataloader), elapsed),
                      flush=True)
                # Periodically sample from the fine model to monitor generation quality.
                fine_model.eval()
                lbl = random.choice(list(index_to_label.values()))
                temp_list = ["<|labelpad|>"] * pad_token_dict[lbl]  # pad_token_dict is expected at module level here
                if len(temp_list) > 0:
                    label_str = " ".join(lbl.split("_")) + " " + " ".join(temp_list)
                else:
                    label_str = " ".join(lbl.split("_"))
                text = fine_tokenizer.bos_token + " " + label_str + " <|labelsep|> "
                sample_outputs = fine_model.generate(
                    input_ids=fine_tokenizer.encode(text, return_tensors='pt').to(device),
                    do_sample=True,
                    top_k=50,
                    max_length=200,
                    top_p=0.95,
                    num_return_sequences=1
                )
                for i, sample_output in enumerate(sample_outputs):
                    print("{}: {}".format(i, fine_tokenizer.decode(sample_output)), flush=True)
                fine_model.train()

            fine_posterior_log_probs = torch.log_softmax(fine_posterior, dim=0)
            print(torch.softmax(fine_posterior, dim=0), flush=True)

            b_coarse_input_ids = batch[0].to(secondary_device)
            b_coarse_labels = batch[0].to(secondary_device)
            b_coarse_input_mask = batch[1].to(secondary_device)

            b_size = b_coarse_input_ids.shape[0]
            b_fine_input_ids_minibatch = batch[2].to(device)
            b_fine_input_mask_minibatch = batch[3].to(device)

            coarse_model.zero_grad()
            # fine_model.zero_grad()
            optimizer.zero_grad()

            outputs = coarse_model(b_coarse_input_ids,
                                   token_type_ids=None,
                                   attention_mask=b_coarse_input_mask,
                                   labels=b_coarse_labels)
            batch_coarse_probs = torch.softmax(outputs[1], dim=-1).to(device)  # (b_size, seq_len, |V|)
            b_coarse_input_ids = b_coarse_input_ids.to(device)
            b_coarse_input_mask = b_coarse_input_mask.to(device)

            batch_fine_probs = []
            batch_fine_input_masks = []
            batch_fine_input_ids = []

            for b_ind in range(b_size):
                fine_label_sum_log_probs = []
                for l_ind in index_to_label:
                    b_fine_input_ids = b_fine_input_ids_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)
                    b_fine_labels = b_fine_input_ids_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)
                    b_fine_input_mask = b_fine_input_mask_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)

                    outputs = fine_model(b_fine_input_ids,
                                         token_type_ids=None,
                                         attention_mask=b_fine_input_mask,
                                         labels=b_fine_labels)
                    b_fine_labels = b_fine_labels.to(secondary_device)
                    fine_log_probs = torch.log_softmax(outputs[1], dim=-1)
                    fine_label_sum_log_probs.append(fine_log_probs + fine_posterior_log_probs[l_ind])

                fine_label_sum_log_probs = torch.cat(fine_label_sum_log_probs, dim=0)  # (|F|, seq_len, |V|)
                batch_fine_probs.append(fine_label_sum_log_probs.unsqueeze(0))
                batch_fine_input_ids.append(b_fine_input_ids)
                batch_fine_input_masks.append(b_fine_input_mask)

            batch_fine_probs = torch.cat(batch_fine_probs, dim=0)  # (b_size, |F|, seq_len, |V|)
            batch_fine_input_masks = torch.cat(batch_fine_input_masks, dim=0)  # (b_size, seq_len)
            batch_fine_input_ids = torch.cat(batch_fine_input_ids, dim=0)  # (b_size, seq_len)

            # This computes log sum_i P(f_i|c) P(D|f_i)
            batch_fine_log_probs = torch.logsumexp(batch_fine_probs, dim=1)

            loss = calculate_loss(batch_fine_log_probs, batch_coarse_probs, batch_fine_input_masks,
                                  b_coarse_input_mask, batch_fine_input_ids, b_coarse_input_ids, coarse_tokenizer,
                                  fine_tokenizer, fine_model, label_to_exclusive_dataloader, doc_start_ind, device,
                                  lambda_1=compute_lambda(global_step, max_steps=len(train_dataloader) * epochs))
            # loss = criterion(batch_fine_probs.log(), batch_coarse_probs.detach()).sum(dim=-1).mean(dim=-1).mean(dim=-1)
            total_train_loss += loss.item()
            print("Loss:", loss.item(), flush=True)
            loss.backward()
            optimizer.step()
            scheduler.step()
            global_step += 1

        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / len(train_dataloader)

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("", flush=True)
        print("  Average training loss: {0:.2f}".format(avg_train_loss), flush=True)
        print("  Training epoch took: {:}".format(training_time), flush=True)

        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.

        print("", flush=True)
        print("Running Validation...", flush=True)

        t0 = time.time()
        fine_model.eval()
        total_eval_loss = 0
        nb_eval_steps = 0

        # Evaluate data for one epoch
        for batch in validation_dataloader:
            # batch contains -> coarse_input_ids, coarse_attention_masks, fine_input_ids, fine_attention_masks
            b_coarse_input_ids = batch[0].to(secondary_device)
            b_coarse_labels = batch[0].to(secondary_device)
            b_coarse_input_mask = batch[1].to(secondary_device)

            b_size = b_coarse_input_ids.shape[0]
            b_fine_input_ids_minibatch = batch[2].to(device)
            b_fine_input_mask_minibatch = batch[3].to(device)

            with torch.no_grad():
                fine_posterior_log_probs = torch.log_softmax(fine_posterior, dim=0)
                outputs = coarse_model(b_coarse_input_ids,
                                       token_type_ids=None,
                                       attention_mask=b_coarse_input_mask,
                                       labels=b_coarse_labels)
                batch_coarse_probs = torch.softmax(outputs[1], dim=-1).to(device)  # (b_size, seq_len, |V|)
                b_coarse_input_ids = b_coarse_input_ids.to(device)
                b_coarse_input_mask = b_coarse_input_mask.to(device)

                batch_fine_probs = []
                batch_fine_input_masks = []
                batch_fine_input_ids = []

                for b_ind in range(b_size):
                    fine_label_sum_log_probs = []
                    for l_ind in index_to_label:
                        b_fine_input_ids = b_fine_input_ids_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)
                        b_fine_labels = b_fine_input_ids_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)
                        b_fine_input_mask = b_fine_input_mask_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)

                        outputs = fine_model(b_fine_input_ids,
                                             token_type_ids=None,
                                             attention_mask=b_fine_input_mask,
                                             labels=b_fine_labels)
                        fine_log_probs = torch.log_softmax(outputs[1], dim=-1)
                        fine_label_sum_log_probs.append(fine_log_probs + fine_posterior_log_probs[l_ind])

                    fine_label_sum_log_probs = torch.cat(fine_label_sum_log_probs, dim=0)  # (|F|, seq_len, |V|)
                    batch_fine_probs.append(fine_label_sum_log_probs.unsqueeze(0))
                    batch_fine_input_ids.append(b_fine_input_ids)
                    batch_fine_input_masks.append(b_fine_input_mask)

                batch_fine_probs = torch.cat(batch_fine_probs, dim=0)  # (b_size, |F|, seq_len, |V|)
                batch_fine_input_masks = torch.cat(batch_fine_input_masks, dim=0)  # (b_size, seq_len)
                batch_fine_input_ids = torch.cat(batch_fine_input_ids, dim=0)  # (b_size, seq_len)

                # This computes log sum_i P(f_i|c) P(D|f_i)
                batch_fine_log_probs = torch.logsumexp(batch_fine_probs, dim=1)

                # Accumulate the validation loss.
                loss = calculate_loss(batch_fine_log_probs, batch_coarse_probs, batch_fine_input_masks,
                                      b_coarse_input_mask, batch_fine_input_ids, b_coarse_input_ids,
                                      coarse_tokenizer, fine_tokenizer, fine_model, label_to_exclusive_dataloader,
                                      doc_start_ind, device, is_val=True,
                                      lambda_1=compute_lambda(global_step, max_steps=len(train_dataloader) * epochs))
            total_eval_loss += loss.item()

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(validation_dataloader)

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)

        print("  Validation Loss: {0:.2f}".format(avg_val_loss), flush=True)
        print("  Validation took: {:}".format(validation_time), flush=True)

        # Record all statistics from this epoch.
        training_stats.append(
            {
                'epoch': epoch_i + 1,
                'Training Loss': avg_train_loss,
                'Valid. Loss': avg_val_loss,
                'Training Time': training_time,
                'Validation Time': validation_time
            }
        )
        # todo: make temp_df, fine_input_ids, fine_attention_masks class variables.
        # true, preds, _ = test(fine_model, fine_posterior, fine_input_ids, fine_attention_masks, doc_start_ind,
        #                       index_to_label, label_to_index, list(temp_df.label.values), device)

    print("", flush=True)
    print("Training complete!", flush=True)
    print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)), flush=True)
    return fine_posterior, fine_model
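
# Sketch (not part of the original training code): torch.nn.KLDivLoss with
# reduction="batchmean" expects its *input* in log space and its *target* as probabilities.
# That is why the fine-grained mixture above is assembled from log_softmax outputs plus a
# log prior and collapsed with logsumexp, while the coarse distribution is a plain softmax.
# The shapes and names below are illustrative only.
def _kl_mixture_sanity_check(num_fine=3, seq_len=4, vocab=10):
    torch.manual_seed(0)
    # Per-fine-label next-token log-probabilities and a log prior over fine labels.
    fine_log_probs = torch.log_softmax(torch.randn(num_fine, seq_len, vocab), dim=-1)
    log_prior = torch.log_softmax(torch.ones(num_fine), dim=0)
    # log sum_i P(f_i | c) P(token | f_i), mirroring batch_fine_log_probs above.
    mixture_log_probs = torch.logsumexp(fine_log_probs + log_prior[:, None, None], dim=0)
    coarse_probs = torch.softmax(torch.randn(seq_len, vocab), dim=-1)
    kl = torch.nn.KLDivLoss(reduction="batchmean")(mixture_log_probs, coarse_probs)
    return kl.item()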
def train(model, tokenizer, coarse_train_dataloader, coarse_validation_dataloader, fine_train_dataloader,
          fine_validation_dataloader, doc_start_ind, parent_labels, child_labels, device):
    def calculate_ce_loss(lm_logits, b_labels, b_input_mask, doc_start_ind):
        loss_fct = CrossEntropyLoss()
        batch_size = lm_logits.shape[0]
        logits_collected = []
        labels_collected = []
        for b in range(batch_size):
            logits_ind = lm_logits[b, :, :]  # seq_len x |V|
            labels_ind = b_labels[b, :]  # seq_len
            mask = b_input_mask[b, :] > 0
            maski = mask.unsqueeze(-1).expand_as(logits_ind)
            # unpad_seq_len x |V|
            logits_pad_removed = torch.masked_select(logits_ind, maski).view(-1, logits_ind.size(-1))
            labels_pad_removed = torch.masked_select(labels_ind, mask)  # unpad_seq_len
            shift_logits = logits_pad_removed[doc_start_ind - 1:-1, :].contiguous()
            shift_labels = labels_pad_removed[doc_start_ind:].contiguous()
            # Flatten the tokens
            logits_collected.append(shift_logits.view(-1, shift_logits.size(-1)))
            labels_collected.append(shift_labels.view(-1))

        logits_collected = torch.cat(logits_collected, dim=0)
        labels_collected = torch.cat(labels_collected, dim=0)
        loss = loss_fct(logits_collected, labels_collected)
        return loss

    def calculate_hinge_loss(fine_log_probs, other_log_probs):
        loss_fct = MarginRankingLoss(margin=1.609)
        length = len(other_log_probs)
        temp_tensor = []
        for i in range(length):
            temp_tensor.append(fine_log_probs)
        temp_tensor = torch.cat(temp_tensor, dim=0)
        other_log_probs = torch.cat(other_log_probs, dim=0)
        y_vec = torch.ones(length).to(device)
        loss = loss_fct(temp_tensor, other_log_probs, y_vec)
        return loss

    def calculate_loss(lm_logits, b_labels, b_input_mask, doc_start_ind, fine_log_probs, other_log_probs,
                       lambda_1=0.01, is_fine=True):
        ce_loss = calculate_ce_loss(lm_logits, b_labels, b_input_mask, doc_start_ind)
        if is_fine:
            hinge_loss = calculate_hinge_loss(fine_log_probs, other_log_probs)
            print("CE-loss", ce_loss.item(), "Hinge-loss", hinge_loss.item(), flush=True)
        else:
            hinge_loss = 0
            print("CE-loss", ce_loss.item(), "Hinge-loss", hinge_loss, flush=True)
        return ce_loss + lambda_1 * hinge_loss

    optimizer = AdamW(
        model.parameters(),
        lr=5e-4,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
        eps=1e-8  # args.adam_epsilon - default is 1e-8.
    )

    sample_every = 100
    warmup_steps = 1e2
    epochs = 5

    total_steps = (len(coarse_train_dataloader) + len(fine_train_dataloader)) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)

    seed_val = 81
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    training_stats = []
    total_t0 = time.time()

    for epoch_i in range(0, epochs):
        print("", flush=True)
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs), flush=True)
        print('Training...', flush=True)

        t0 = time.time()
        total_train_loss = 0
        model.train()

        for step, batch in enumerate(coarse_train_dataloader):
            if step % sample_every == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.  Elapsed: {:}.'.format(step, len(coarse_train_dataloader),
                                                                          elapsed), flush=True)
                # Periodically sample from the model conditioned on a random parent label.
                model.eval()
                lbl = random.choice(parent_labels)
                temp_list = ["<|labelpad|>"] * pad_token_dict[lbl]  # pad_token_dict is expected at module level here
                if len(temp_list) > 0:
                    label_str = " ".join(lbl.split("_")) + " " + " ".join(temp_list)
                else:
                    label_str = " ".join(lbl.split("_"))
                text = tokenizer.bos_token + " " + label_str + " <|labelsep|> "
                sample_outputs = model.generate(input_ids=tokenizer.encode(text, return_tensors='pt').to(device),
                                                do_sample=True,
                                                top_k=50,
                                                max_length=200,
                                                top_p=0.95,
                                                num_return_sequences=1)
                for i, sample_output in enumerate(sample_outputs):
                    print("{}: {}".format(i, tokenizer.decode(sample_output)), flush=True)
                model.train()

            b_input_ids = batch[0].to(device)
            b_labels = batch[0].to(device)
            b_input_mask = batch[1].to(device)

            model.zero_grad()
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)

            loss = calculate_loss(outputs[1], b_labels, b_input_mask, doc_start_ind, None, None, is_fine=False)
            # loss = outputs[0]
            total_train_loss += loss.item()

            loss.backward()
            optimizer.step()
            scheduler.step()

        for step, batch in enumerate(fine_train_dataloader):
            # batch contains -> fine_input_ids mini batch, fine_attention_masks mini batch
            if step % sample_every == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.  Elapsed: {:}.'.format(step, len(fine_train_dataloader),
                                                                          elapsed), flush=True)
                # Periodically sample from the model conditioned on a random child label.
                model.eval()
                lbl = random.choice(child_labels)
                temp_list = ["<|labelpad|>"] * pad_token_dict[lbl]
                if len(temp_list) > 0:
                    label_str = " ".join(lbl.split("_")) + " " + " ".join(temp_list)
                else:
                    label_str = " ".join(lbl.split("_"))
                text = tokenizer.bos_token + " " + label_str + " <|labelsep|> "
                sample_outputs = model.generate(input_ids=tokenizer.encode(text, return_tensors='pt').to(device),
                                                do_sample=True,
                                                top_k=50,
                                                max_length=200,
                                                top_p=0.95,
                                                num_return_sequences=1)
                for i, sample_output in enumerate(sample_outputs):
                    print("{}: {}".format(i, tokenizer.decode(sample_output)), flush=True)
                model.train()

            b_fine_input_ids_minibatch = batch[0].to(device)
            b_fine_input_mask_minibatch = batch[1].to(device)

            b_size = b_fine_input_ids_minibatch.shape[0]
            assert b_size == 1
            mini_batch_size = b_fine_input_ids_minibatch.shape[1]

            model.zero_grad()

            batch_other_log_probs = []
            prev_mask = None
            for b_ind in range(b_size):
                for mini_batch_ind in range(mini_batch_size):
                    b_fine_input_ids = b_fine_input_ids_minibatch[b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                    b_fine_labels = b_fine_input_ids_minibatch[b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                    b_fine_input_mask = b_fine_input_mask_minibatch[b_ind, mini_batch_ind, :].unsqueeze(0).to(device)

                    outputs = model(b_fine_input_ids,
                                    token_type_ids=None,
                                    attention_mask=b_fine_input_mask,
                                    labels=b_fine_labels)
                    log_probs = torch.log_softmax(outputs[1], dim=-1)
                    doc_prob = compute_doc_prob(log_probs, b_fine_input_mask, b_fine_labels,
                                                doc_start_ind).unsqueeze(0)
                    if mini_batch_ind == 0:
                        # The first entry of the mini batch corresponds to the true fine-grained label.
                        batch_fine_log_probs = doc_prob
                        orig_output = outputs
                        orig_labels = b_fine_labels
                        orig_mask = b_fine_input_mask
                    else:
                        batch_other_log_probs.append(doc_prob)
                    if prev_mask is not None:
                        assert torch.all(b_fine_input_mask.eq(prev_mask))
                    prev_mask = b_fine_input_mask

            loss = calculate_loss(orig_output[1], orig_labels, orig_mask, doc_start_ind, batch_fine_log_probs,
                                  batch_other_log_probs, is_fine=True)
            # loss = criterion(batch_fine_probs.log(), batch_coarse_probs.detach()).sum(dim=-1).mean(dim=-1).mean(dim=-1)
            total_train_loss += loss.item()
            print("Loss:", loss.item(), flush=True)
            loss.backward()
            optimizer.step()
            scheduler.step()

        # **********************************
        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / (len(coarse_train_dataloader) + len(fine_train_dataloader))

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("", flush=True)
        print("  Average training loss: {0:.2f}".format(avg_train_loss), flush=True)
        print("  Training epoch took: {:}".format(training_time), flush=True)

        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.

        print("", flush=True)
        print("Running Validation...", flush=True)

        t0 = time.time()
        model.eval()
        total_eval_loss = 0
        nb_eval_steps = 0

        # Evaluate data for one epoch
        for batch in coarse_validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_labels = batch[0].to(device)
            b_input_mask = batch[1].to(device)

            with torch.no_grad():
                outputs = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask,
                                labels=b_labels)
                # Accumulate the validation loss.
                loss = calculate_loss(outputs[1], b_labels, b_input_mask, doc_start_ind, None, None, is_fine=False)
                # loss = outputs[0]
            total_eval_loss += loss.item()

        for batch in fine_validation_dataloader:
            # batch contains -> fine_input_ids mini batch, fine_attention_masks mini batch
            b_fine_input_ids_minibatch = batch[0].to(device)
            b_fine_input_mask_minibatch = batch[1].to(device)

            b_size = b_fine_input_ids_minibatch.shape[0]
            assert b_size == 1
            mini_batch_size = b_fine_input_ids_minibatch.shape[1]

            with torch.no_grad():
                batch_other_log_probs = []
                prev_mask = None
                for b_ind in range(b_size):
                    for mini_batch_ind in range(mini_batch_size):
                        b_fine_input_ids = b_fine_input_ids_minibatch[b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                        b_fine_labels = b_fine_input_ids_minibatch[b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                        b_fine_input_mask = b_fine_input_mask_minibatch[b_ind, mini_batch_ind, :].unsqueeze(0).to(device)

                        outputs = model(b_fine_input_ids,
                                        token_type_ids=None,
                                        attention_mask=b_fine_input_mask,
                                        labels=b_fine_labels)
                        log_probs = torch.log_softmax(outputs[1], dim=-1)
                        doc_prob = compute_doc_prob(log_probs, b_fine_input_mask, b_fine_labels,
                                                    doc_start_ind).unsqueeze(0)
                        if mini_batch_ind == 0:
                            batch_fine_log_probs = doc_prob
                            orig_output = outputs
                            orig_labels = b_fine_labels
                            orig_mask = b_fine_input_mask
                        else:
                            batch_other_log_probs.append(doc_prob)
                        if prev_mask is not None:
                            assert torch.all(b_fine_input_mask.eq(prev_mask))
                        prev_mask = b_fine_input_mask

                loss = calculate_loss(orig_output[1], orig_labels, orig_mask, doc_start_ind, batch_fine_log_probs,
                                      batch_other_log_probs, is_fine=True)
            total_eval_loss += loss.item()

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / (len(coarse_validation_dataloader) + len(fine_validation_dataloader))

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)

        print("  Validation Loss: {0:.2f}".format(avg_val_loss), flush=True)
        print("  Validation took: {:}".format(validation_time), flush=True)

        # Record all statistics from this epoch.
        training_stats.append({
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        })

    print("", flush=True)
    print("Training complete!", flush=True)
    print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)), flush=True)
    return model
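
# NOTE: compute_doc_prob(...) is used in the fine-grained loops above but defined elsewhere
# in the repository. A minimal sketch under the assumption that it scores a document under
# the model: it follows the same pad-removal / shift convention as calculate_ce_loss and
# returns the summed log-probability of the document tokens after doc_start_ind. The
# original may normalise differently (e.g. average per token).
def compute_doc_prob(log_probs, b_input_mask, b_labels, doc_start_ind):
    # log_probs: (1, seq_len, |V|) next-token log-probabilities; b_labels: (1, seq_len) token ids.
    mask = b_input_mask[0, :] > 0
    maski = mask.unsqueeze(-1).expand_as(log_probs[0])
    log_probs_pad_removed = torch.masked_select(log_probs[0], maski).view(-1, log_probs.size(-1))
    labels_pad_removed = torch.masked_select(b_labels[0, :], mask)

    shift_log_probs = log_probs_pad_removed[doc_start_ind - 1:-1, :]  # predictions for document positions
    shift_labels = labels_pad_removed[doc_start_ind:]  # gold document tokens

    token_log_probs = shift_log_probs.gather(1, shift_labels.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum()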