Example #1
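
# The three train() variants below rely on the usual PyTorch / Hugging Face imports plus a few
# project-specific helpers (format_time(), compute_doc_prob() and the dict pad_token_dict),
# which are assumed to be defined elsewhere in the codebase. A minimal header along these
# lines is assumed (not part of the original listing):
import random
import time

import numpy as np
import torch
from torch.nn import CrossEntropyLoss, MarginRankingLoss
from transformers import AdamW, get_linear_schedule_with_warmup
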
def train(model, tokenizer, coarse_train_dataloader,
          coarse_validation_dataloader, doc_start_ind, all_labels, device,
          pad_token_dict):
    def calculate_loss(lm_logits, b_labels, b_input_mask, doc_start_ind):
        loss_fct = CrossEntropyLoss()
        batch_size = lm_logits.shape[0]
        logits_collected = []
        labels_collected = []
        for b in range(batch_size):
            logits_ind = lm_logits[b, :, :]  # seq_len x |V|
            labels_ind = b_labels[b, :]  # seq_len
            mask = b_input_mask[b, :] > 0
            maski = mask.unsqueeze(-1).expand_as(logits_ind)
            # unpad_seq_len x |V|
            logits_pad_removed = torch.masked_select(logits_ind, maski).view(
                -1, logits_ind.size(-1))
            labels_pad_removed = torch.masked_select(labels_ind,
                                                     mask)  # unpad_seq_len

            shift_logits = logits_pad_removed[doc_start_ind -
                                              1:-1, :].contiguous()
            shift_labels = labels_pad_removed[doc_start_ind:].contiguous()
            # Flatten the tokens
            logits_collected.append(
                shift_logits.view(-1, shift_logits.size(-1)))
            labels_collected.append(shift_labels.view(-1))

        logits_collected = torch.cat(logits_collected, dim=0)
        labels_collected = torch.cat(labels_collected, dim=0)
        loss = loss_fct(logits_collected, labels_collected)
        return loss

    optimizer = AdamW(
        model.parameters(),
        lr=5e-4,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
        eps=1e-8  # args.adam_epsilon  - default is 1e-8.
    )

    sample_every = 100
    warmup_steps = 1e2
    epochs = 5
    total_steps = len(coarse_train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)
    seed_val = 42
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    training_stats = []
    total_t0 = time.time()

    for epoch_i in range(0, epochs):
        print("", flush=True)
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs),
              flush=True)
        print('Training...', flush=True)
        t0 = time.time()
        total_train_loss = 0
        model.train()

        for step, batch in enumerate(coarse_train_dataloader):
            if step % sample_every == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(
                    step, len(coarse_train_dataloader), elapsed),
                      flush=True)
                model.eval()
                lbl = random.choice(all_labels)
                temp_list = ["<|labelpad|>"] * pad_token_dict[lbl]
                if len(temp_list) > 0:
                    label_str = " ".join(
                        lbl.split("_")) + " " + " ".join(temp_list)
                else:
                    label_str = " ".join(lbl.split("_"))
                text = tokenizer.bos_token + " " + label_str + " <|labelsep|> "
                sample_outputs = model.generate(input_ids=tokenizer.encode(
                    text, return_tensors='pt').to(device),
                                                do_sample=True,
                                                top_k=50,
                                                max_length=200,
                                                top_p=0.95,
                                                num_return_sequences=1)
                for i, sample_output in enumerate(sample_outputs):
                    print("{}: {}".format(i, tokenizer.decode(sample_output)),
                          flush=True)
                model.train()

            b_input_ids = batch[0].to(device)
            b_labels = batch[0].to(device)
            b_input_mask = batch[1].to(device)

            model.zero_grad()

            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)

            loss = calculate_loss(outputs[1], b_labels, b_input_mask,
                                  doc_start_ind)
            # loss = outputs[0]
            total_train_loss += loss.item()

            loss.backward()
            optimizer.step()
            scheduler.step()

        # **********************************

        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / len(coarse_train_dataloader)

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("", flush=True)
        print("  Average training loss: {0:.2f}".format(avg_train_loss),
              flush=True)
        print("  Training epcoh took: {:}".format(training_time), flush=True)

        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.

        print("", flush=True)
        print("Running Validation...", flush=True)

        t0 = time.time()

        model.eval()

        total_eval_loss = 0
        nb_eval_steps = 0

        # Evaluate data for one epoch
        for batch in coarse_validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_labels = batch[0].to(device)
            b_input_mask = batch[1].to(device)

            with torch.no_grad():
                outputs = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask,
                                labels=b_labels)

            # Accumulate the validation loss.
            loss = calculate_loss(outputs[1], b_labels, b_input_mask,
                                  doc_start_ind)
            # loss = outputs[0]
            total_eval_loss += loss.item()

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / (len(coarse_validation_dataloader))

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)

        print("  Validation Loss: {0:.2f}".format(avg_val_loss), flush=True)
        print("  Validation took: {:}".format(validation_time), flush=True)

        # Record all statistics from this epoch.
        training_stats.append({
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        })

    print("", flush=True)
    print("Training complete!", flush=True)

    print("Total training took {:} (h:mm:ss)".format(
        format_time(time.time() - total_t0)),
          flush=True)
    return model
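
# A quick, self-contained check of the shift used in calculate_loss above, on synthetic
# tensors (not part of the original listing): with doc_start_ind = d, the logits at
# positions d-1 .. L-2 are scored against the tokens at positions d .. L-1, so the label
# prefix before the document text is excluded from the loss.
_d = 3                                  # doc_start_ind: index of the first document token
_L, _V = 8, 11                          # sequence length and vocabulary size
_logits = torch.randn(_L, _V)           # per-position next-token logits
_labels = torch.randint(0, _V, (_L,))   # input ids reused as labels
_shift_logits = _logits[_d - 1:-1, :]   # positions d-1 .. L-2 predict ...
_shift_labels = _labels[_d:]            # ... tokens d .. L-1
assert _shift_logits.shape[0] == _shift_labels.shape[0] == _L - _d
_example_ce = CrossEntropyLoss()(_shift_logits, _shift_labels)
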
def train(coarse_model, fine_model, coarse_tokenizer, fine_tokenizer, train_dataloader, validation_dataloader,
          label_to_exclusive_dataloader, doc_start_ind, index_to_label, device, secondary_device):
    def calculate_kl_div_loss(batch_fine_probs, batch_coarse_probs, batch_fine_input_masks, batch_coarse_input_masks,
                              batch_fine_input_ids, batch_coarse_input_ids, coarse_tokenizer, fine_tokenizer,
                              doc_start_ind):
        # batch_fine_probs holds per-token log-probabilities (the KLDivLoss input) and
        # batch_coarse_probs holds probabilities (the KLDivLoss target); pad tokens are
        # removed and scoring starts at doc_start_ind - 1, mirroring the shift in the
        # cross-entropy losses.
        loss_fct = torch.nn.KLDivLoss(reduction="batchmean")
        batch_size = batch_fine_probs.shape[0]
        losses = []
        for b in range(batch_size):
            fine_logits_ind = batch_fine_probs[b, :, :]  # seq_len x |V|
            coarse_logits_ind = batch_coarse_probs[b, :, :]  # seq_len x |V|
            fine_mask = batch_fine_input_masks[b, :] > 0
            coarse_mask = batch_coarse_input_masks[b, :] > 0
            if not torch.all(fine_mask.eq(coarse_mask)):
                print("Fine sentence", fine_tokenizer.decode(batch_fine_input_ids[b, :]))
                print("Coarse sentence", coarse_tokenizer.decode(batch_coarse_input_ids[b, :]))
                raise Exception("Fine and Coarse mask is not same")

            fine_dec_sent = fine_tokenizer.decode(batch_fine_input_ids[b, :][doc_start_ind:])
            coarse_dec_sent = coarse_tokenizer.decode(batch_coarse_input_ids[b, :][doc_start_ind:])

            if fine_dec_sent != coarse_dec_sent:
                print("Fine sentence ", fine_tokenizer.decode(batch_fine_input_ids[b, :][doc_start_ind:]))
                print("Coarse sentence ", coarse_tokenizer.decode(batch_coarse_input_ids[b, :][doc_start_ind:]))
                raise Exception("Fine and Coarse decoded sentence is not same")

            fine_maski = fine_mask.unsqueeze(-1).expand_as(fine_logits_ind)
            coarse_maski = coarse_mask.unsqueeze(-1).expand_as(coarse_logits_ind)
            # unpad_seq_len x |V|
            fine_logits_pad_removed = torch.masked_select(
                fine_logits_ind, fine_maski).view(-1, fine_logits_ind.size(-1))
            coarse_logits_pad_removed = torch.masked_select(
                coarse_logits_ind, coarse_maski).view(-1, coarse_logits_ind.size(-1))
            shift_fine_logits = fine_logits_pad_removed[doc_start_ind - 1:-1, :].contiguous()
            shift_coarse_logits = coarse_logits_pad_removed[doc_start_ind - 1:-1, :].contiguous()
            # Compute loss here of shift_fine_logits and shift_coarse_logits append to losses
            loss = loss_fct(shift_fine_logits, shift_coarse_logits).unsqueeze(0)
            losses.append(loss)

        # Return mean of losses here
        losses = torch.cat(losses, dim=0)
        return losses.mean()

    def calculate_cross_entropy_loss(fine_model, label_to_exclusive_dataloader, doc_start_ind, device):
        loss_function = CrossEntropyLoss()

        b_labels_list = []
        b_input_ids_list = []
        b_input_mask_list = []
        scores_list = []

        # for l in label_to_exclusive_dataloader:
        selected_labs = random.sample(list(label_to_exclusive_dataloader.keys()), 1)
        for l in selected_labs:
            # print("Label", l)
            dataloader = label_to_exclusive_dataloader[l]
            it = 0
            for step, batch in enumerate(dataloader):
                # print("Step for exc", step, it)
                b_input_ids = batch[0].to(device)
                b_labels = batch[0].to(device)
                b_input_mask = batch[1].to(device)

                outputs = fine_model(b_input_ids,
                                     token_type_ids=None,
                                     attention_mask=b_input_mask,
                                     labels=b_labels)
                b_labels_list.append(b_labels)
                b_input_ids_list.append(b_input_ids)
                b_input_mask_list.append(b_input_mask)
                scores_list.append(outputs[1])
                # reporter = MemReporter()
                # reporter.report()
                it += 1
                if it == 1:
                    break

        b_labels_tensor = torch.cat(b_labels_list, dim=0)
        b_input_ids_tensor = torch.cat(b_input_ids_list, dim=0)
        b_input_mask_tensor = torch.cat(b_input_mask_list, dim=0)
        scores_tensor = torch.cat(scores_list, dim=0)

        assert b_labels_tensor.shape[0] == b_input_ids_tensor.shape[0] == b_input_mask_tensor.shape[0] == \
               scores_tensor.shape[0]
        batch_size = scores_tensor.shape[0]
        logits_collected = []
        labels_collected = []
        for b in range(batch_size):
            logits_ind = scores_tensor[b, :, :]  # seq_len x |V|
            labels_ind = b_labels_tensor[b, :]  # seq_len
            mask = b_input_mask_tensor[b, :] > 0
            maski = mask.unsqueeze(-1).expand_as(logits_ind)
            # unpad_seq_len x |V|
            logits_pad_removed = torch.masked_select(logits_ind, maski).view(-1, logits_ind.size(-1))
            labels_pad_removed = torch.masked_select(labels_ind, mask)  # unpad_seq_len

            shift_logits = logits_pad_removed[doc_start_ind - 1:-1, :].contiguous()
            shift_labels = labels_pad_removed[doc_start_ind:].contiguous()
            # Flatten the tokens
            logits_collected.append(shift_logits.view(-1, shift_logits.size(-1)))
            labels_collected.append(shift_labels.view(-1))

        logits_collected = torch.cat(logits_collected, dim=0)
        labels_collected = torch.cat(labels_collected, dim=0)
        loss = loss_function(logits_collected, labels_collected).to(device)
        return loss

    def calculate_loss(batch_fine_probs,
                       batch_coarse_probs,
                       batch_fine_input_masks,
                       batch_coarse_input_masks,
                       batch_fine_input_ids,
                       batch_coarse_input_ids,
                       coarse_tokenizer,
                       fine_tokenizer,
                       fine_model,
                       label_to_exclusive_dataloader,
                       doc_start_ind,
                       device,
                       lambda_1=5,
                       is_val=False):
        kl_div_loss = calculate_kl_div_loss(batch_fine_probs, batch_coarse_probs, batch_fine_input_masks,
                                            batch_coarse_input_masks, batch_fine_input_ids, batch_coarse_input_ids,
                                            coarse_tokenizer, fine_tokenizer, doc_start_ind)
        # del batch_fine_probs
        # del batch_coarse_probs
        # del batch_fine_input_masks
        # del batch_coarse_input_masks
        # del batch_fine_input_ids
        # del batch_coarse_input_ids
        # torch.cuda.empty_cache()
        if not is_val:
            cross_ent_loss = calculate_cross_entropy_loss(fine_model, label_to_exclusive_dataloader, doc_start_ind,
                                                          device)
            print("KL-loss", kl_div_loss.item(), "CE-loss", cross_ent_loss.item())
        else:
            cross_ent_loss = 0
            print("KL-loss", kl_div_loss.item(), "CE-loss", cross_ent_loss)
        return (1 - lambda_1) * kl_div_loss + lambda_1 * cross_ent_loss

    def compute_lambda(step, max_steps):
        temp = 1 - step / max_steps
        if temp < 0:
            return 0
        else:
            return temp

    # epsilon = 1e-20  # Defined to avoid log probability getting undefined.
    fine_posterior = torch.nn.Parameter(torch.ones(len(index_to_label)).to(device))
    optimizer = AdamW(list(fine_model.parameters()) + [fine_posterior],
                      lr=5e-4,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
                      eps=1e-8  # args.adam_epsilon  - default is 1e-8.
                      )
    sample_every = 100
    warmup_steps = 1e2
    epochs = 2
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)
    seed_val = 42
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    training_stats = []
    total_t0 = time.time()

    coarse_model.eval()
    global_step = 0

    for epoch_i in range(0, epochs):
        print("", flush=True)
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs), flush=True)
        print('Training...', flush=True)
        t0 = time.time()
        total_train_loss = 0
        fine_model.train()

        for step, batch in enumerate(train_dataloader):
            # batch contains -> coarse_input_ids, coarse_attention_masks, fine_input_ids, fine_attention_masks
            if step % sample_every == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed),
                      flush=True)
                fine_model.eval()
                lbl = random.choice(list(index_to_label.values()))
                temp_list = ["<|labelpad|>"] * pad_token_dict[lbl]
                if len(temp_list) > 0:
                    label_str = " ".join(lbl.split("_")) + " " + " ".join(temp_list)
                else:
                    label_str = " ".join(lbl.split("_"))
                text = fine_tokenizer.bos_token + " " + label_str + " <|labelsep|> "
                sample_outputs = fine_model.generate(
                    input_ids=fine_tokenizer.encode(text, return_tensors='pt').to(device),
                    do_sample=True,
                    top_k=50,
                    max_length=200,
                    top_p=0.95,
                    num_return_sequences=1
                )
                for i, sample_output in enumerate(sample_outputs):
                    print("{}: {}".format(i, fine_tokenizer.decode(sample_output)), flush=True)
                fine_model.train()

            fine_posterior_log_probs = torch.log_softmax(fine_posterior, dim=0)
            print(torch.softmax(fine_posterior, dim=0), flush=True)

            b_coarse_input_ids = batch[0].to(secondary_device)
            b_coarse_labels = batch[0].to(secondary_device)
            b_coarse_input_mask = batch[1].to(secondary_device)

            b_size = b_coarse_input_ids.shape[0]

            b_fine_input_ids_minibatch = batch[2].to(device)
            b_fine_input_mask_minibatch = batch[3].to(device)

            coarse_model.zero_grad()
            # fine_model.zero_grad()
            optimizer.zero_grad()

            outputs = coarse_model(b_coarse_input_ids,
                                   token_type_ids=None,
                                   attention_mask=b_coarse_input_mask,
                                   labels=b_coarse_labels)

            batch_coarse_probs = torch.softmax(outputs[1], dim=-1).to(device)  # (b_size, seq_len, |V|)
            b_coarse_input_ids = b_coarse_input_ids.to(device)
            b_coarse_input_mask = b_coarse_input_mask.to(device)

            batch_fine_probs = []
            batch_fine_input_masks = []
            batch_fine_input_ids = []
            for b_ind in range(b_size):
                fine_label_sum_log_probs = []
                for l_ind in index_to_label:
                    b_fine_input_ids = b_fine_input_ids_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)
                    b_fine_labels = b_fine_input_ids_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)
                    b_fine_input_mask = b_fine_input_mask_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)

                    outputs = fine_model(b_fine_input_ids,
                                         token_type_ids=None,
                                         attention_mask=b_fine_input_mask,
                                         labels=b_fine_labels)

                    b_fine_labels = b_fine_labels.to(secondary_device)

                    fine_log_probs = torch.log_softmax(outputs[1], dim=-1)
                    fine_label_sum_log_probs.append((fine_log_probs + fine_posterior_log_probs[l_ind]))

                fine_label_sum_log_probs = torch.cat(fine_label_sum_log_probs, dim=0)  # (|F|, seq_len, |V|)
                batch_fine_probs.append(fine_label_sum_log_probs.unsqueeze(0))
                batch_fine_input_ids.append(b_fine_input_ids)
                batch_fine_input_masks.append(b_fine_input_mask)

            batch_fine_probs = torch.cat(batch_fine_probs, dim=0)  # (b_size, |F|, seq_len, |V|)
            batch_fine_input_masks = torch.cat(batch_fine_input_masks, dim=0)  # (b_size, seq_len)
            batch_fine_input_ids = torch.cat(batch_fine_input_ids, dim=0)  # (b_size, seq_len)
            batch_fine_log_probs = torch.logsumexp(batch_fine_probs, dim=1)  # This computes logsum_i P(f_i|c) P(D|f_i)

            loss = calculate_loss(batch_fine_log_probs,
                                  batch_coarse_probs,
                                  batch_fine_input_masks,
                                  b_coarse_input_mask,
                                  batch_fine_input_ids,
                                  b_coarse_input_ids,
                                  coarse_tokenizer,
                                  fine_tokenizer,
                                  fine_model,
                                  label_to_exclusive_dataloader,
                                  doc_start_ind,
                                  device,
                                  lambda_1=compute_lambda(global_step, max_steps=len(train_dataloader) * epochs))
            # loss = criterion(batch_fine_probs.log(), batch_coarse_probs.detach()).sum(dim=-1).mean(dim=-1).mean(dim=-1)
            total_train_loss += loss.item()
            print("Loss:", loss.item(), flush=True)

            loss.backward()
            optimizer.step()
            scheduler.step()
            global_step += 1

        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / len(train_dataloader)

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("", flush=True)
        print("  Average training loss: {0:.2f}".format(avg_train_loss), flush=True)
        print("  Training epoch took: {:}".format(training_time), flush=True)

        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.

        print("", flush=True)
        print("Running Validation...", flush=True)

        t0 = time.time()

        fine_model.eval()

        total_eval_loss = 0
        nb_eval_steps = 0

        # Evaluate data for one epoch
        for batch in validation_dataloader:
            # batch contains -> coarse_input_ids, coarse_attention_masks, fine_input_ids, fine_attention_masks
            b_coarse_input_ids = batch[0].to(secondary_device)
            b_coarse_labels = batch[0].to(secondary_device)
            b_coarse_input_mask = batch[1].to(secondary_device)

            b_size = b_coarse_input_ids.shape[0]

            b_fine_input_ids_minibatch = batch[2].to(device)
            b_fine_input_mask_minibatch = batch[3].to(device)

            with torch.no_grad():
                fine_posterior_log_probs = torch.log_softmax(fine_posterior, dim=0)
                outputs = coarse_model(b_coarse_input_ids,
                                       token_type_ids=None,
                                       attention_mask=b_coarse_input_mask,
                                       labels=b_coarse_labels)

                batch_coarse_probs = torch.softmax(outputs[1], dim=-1).to(device)  # (b_size, seq_len, |V|)

                b_coarse_input_ids = b_coarse_input_ids.to(device)
                b_coarse_input_mask = b_coarse_input_mask.to(device)

                batch_fine_probs = []
                batch_fine_input_masks = []
                batch_fine_input_ids = []
                for b_ind in range(b_size):
                    fine_label_sum_log_probs = []
                    for l_ind in index_to_label:
                        b_fine_input_ids = b_fine_input_ids_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)
                        b_fine_labels = b_fine_input_ids_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)
                        b_fine_input_mask = b_fine_input_mask_minibatch[b_ind, l_ind, :].unsqueeze(0).to(device)

                        outputs = fine_model(b_fine_input_ids,
                                             token_type_ids=None,
                                             attention_mask=b_fine_input_mask,
                                             labels=b_fine_labels)
                        fine_log_probs = torch.log_softmax(outputs[1], dim=-1)
                        fine_label_sum_log_probs.append((fine_log_probs + fine_posterior_log_probs[l_ind]))

                    fine_label_sum_log_probs = torch.cat(fine_label_sum_log_probs, dim=0)  # (|F|, seq_len, |V|)
                    batch_fine_probs.append(fine_label_sum_log_probs.unsqueeze(0))
                    batch_fine_input_ids.append(b_fine_input_ids)
                    batch_fine_input_masks.append(b_fine_input_mask)

                batch_fine_probs = torch.cat(batch_fine_probs, dim=0)  # (b_size, |F|, seq_len, |V|)
                batch_fine_input_masks = torch.cat(batch_fine_input_masks, dim=0)  # (b_size, seq_len)
                batch_fine_input_ids = torch.cat(batch_fine_input_ids, dim=0)  # (b_size, seq_len)
                batch_fine_log_probs = torch.logsumexp(batch_fine_probs,
                                                       dim=1)  # This computes logsum_i P(f_i|c) P(D|f_i)

            # Accumulate the validation loss.
            loss = calculate_loss(batch_fine_log_probs,
                                  batch_coarse_probs,
                                  batch_fine_input_masks,
                                  b_coarse_input_mask,
                                  batch_fine_input_ids,
                                  b_coarse_input_ids,
                                  coarse_tokenizer,
                                  fine_tokenizer,
                                  fine_model,
                                  label_to_exclusive_dataloader,
                                  doc_start_ind,
                                  device,
                                  is_val=True,
                                  lambda_1=compute_lambda(global_step, max_steps=len(train_dataloader) * epochs))
            total_eval_loss += loss.item()

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(validation_dataloader)

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)

        print("  Validation Loss: {0:.2f}".format(avg_val_loss), flush=True)
        print("  Validation took: {:}".format(validation_time), flush=True)

        # Record all statistics from this epoch.
        training_stats.append(
            {
                'epoch': epoch_i + 1,
                'Training Loss': avg_train_loss,
                'Valid. Loss': avg_val_loss,
                'Training Time': training_time,
                'Validation Time': validation_time
            }
        )

        # todo make temp_df, fine_input_ids, fine_attention_masks class variables.
        # true, preds, _ = test(fine_model, fine_posterior, fine_input_ids, fine_attention_masks, doc_start_ind,
        #                       index_to_label, label_to_index, list(temp_df.label.values), device)

    print("", flush=True)
    print("Training complete!", flush=True)

    print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)), flush=True)
    return fine_posterior, fine_model
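
# A compact sanity check, on synthetic tensors, of the two ingredients of the second
# train() above (not part of the original listing): per-label log-probabilities are mixed
# with the label prior via logsumexp, i.e. log sum_i P(f_i|c) P(token|f_i), and KLDivLoss
# is called in the order it expects (input = log-probabilities, target = probabilities),
# giving KL(coarse teacher || fine mixture).
_F, _T, _V = 4, 6, 11                                                # fine labels, seq len, vocab
_prior_log = torch.log_softmax(torch.ones(_F), dim=0)                # log P(f_i|c), cf. fine_posterior
_per_label_log = torch.log_softmax(torch.randn(_F, _T, _V), dim=-1)  # log P(token|f_i)
_mixture_log = torch.logsumexp(_per_label_log + _prior_log.view(_F, 1, 1), dim=0)
_teacher_probs = torch.softmax(torch.randn(_T, _V), dim=-1)          # coarse model distribution
_example_kl = torch.nn.KLDivLoss(reduction="batchmean")(_mixture_log, _teacher_probs)
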
def train(model, tokenizer, coarse_train_dataloader,
          coarse_validation_dataloader, fine_train_dataloader,
          fine_validation_dataloader, doc_start_ind, parent_labels,
          child_labels, device):
    def calculate_ce_loss(lm_logits, b_labels, b_input_mask, doc_start_ind):
        loss_fct = CrossEntropyLoss()
        batch_size = lm_logits.shape[0]
        logits_collected = []
        labels_collected = []
        for b in range(batch_size):
            logits_ind = lm_logits[b, :, :]  # seq_len x |V|
            labels_ind = b_labels[b, :]  # seq_len
            mask = b_input_mask[b, :] > 0
            maski = mask.unsqueeze(-1).expand_as(logits_ind)
            # unpad_seq_len x |V|
            logits_pad_removed = torch.masked_select(logits_ind, maski).view(
                -1, logits_ind.size(-1))
            labels_pad_removed = torch.masked_select(labels_ind,
                                                     mask)  # unpad_seq_len

            shift_logits = logits_pad_removed[doc_start_ind -
                                              1:-1, :].contiguous()
            shift_labels = labels_pad_removed[doc_start_ind:].contiguous()
            # Flatten the tokens
            logits_collected.append(
                shift_logits.view(-1, shift_logits.size(-1)))
            labels_collected.append(shift_labels.view(-1))

        logits_collected = torch.cat(logits_collected, dim=0)
        labels_collected = torch.cat(labels_collected, dim=0)
        loss = loss_fct(logits_collected, labels_collected)
        return loss

    def calculate_hinge_loss(fine_log_probs, other_log_probs):
        loss_fct = MarginRankingLoss(margin=1.609)
        length = len(other_log_probs)
        temp_tensor = []
        for i in range(length):
            temp_tensor.append(fine_log_probs)
        temp_tensor = torch.cat(temp_tensor, dim=0)
        other_log_probs = torch.cat(other_log_probs, dim=0)
        y_vec = torch.ones(length).to(device)
        loss = loss_fct(temp_tensor, other_log_probs, y_vec)
        return loss

    def calculate_loss(lm_logits,
                       b_labels,
                       b_input_mask,
                       doc_start_ind,
                       fine_log_probs,
                       other_log_probs,
                       lambda_1=0.01,
                       is_fine=True):
        ce_loss = calculate_ce_loss(lm_logits, b_labels, b_input_mask,
                                    doc_start_ind)
        if is_fine:
            hinge_loss = calculate_hinge_loss(fine_log_probs, other_log_probs)
            print("CE-loss",
                  ce_loss.item(),
                  "Hinge-loss",
                  hinge_loss.item(),
                  flush=True)
        else:
            hinge_loss = 0
            print("CE-loss",
                  ce_loss.item(),
                  "Hinge-loss",
                  hinge_loss,
                  flush=True)
        return ce_loss + lambda_1 * hinge_loss

    optimizer = AdamW(
        model.parameters(),
        lr=5e-4,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
        eps=1e-8  # args.adam_epsilon  - default is 1e-8.
    )

    sample_every = 100
    warmup_steps = 1e2
    epochs = 5
    total_steps = (len(coarse_train_dataloader) +
                   len(fine_train_dataloader)) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)
    seed_val = 81
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    training_stats = []
    total_t0 = time.time()

    for epoch_i in range(0, epochs):
        print("", flush=True)
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs),
              flush=True)
        print('Training...', flush=True)
        t0 = time.time()
        total_train_loss = 0
        model.train()

        for step, batch in enumerate(coarse_train_dataloader):
            if step % sample_every == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(
                    step, len(coarse_train_dataloader), elapsed),
                      flush=True)
                model.eval()
                lbl = random.choice(parent_labels)
                temp_list = ["<|labelpad|>"] * pad_token_dict[lbl]
                if len(temp_list) > 0:
                    label_str = " ".join(
                        lbl.split("_")) + " " + " ".join(temp_list)
                else:
                    label_str = " ".join(lbl.split("_"))
                text = tokenizer.bos_token + " " + label_str + " <|labelsep|> "
                sample_outputs = model.generate(input_ids=tokenizer.encode(
                    text, return_tensors='pt').to(device),
                                                do_sample=True,
                                                top_k=50,
                                                max_length=200,
                                                top_p=0.95,
                                                num_return_sequences=1)
                for i, sample_output in enumerate(sample_outputs):
                    print("{}: {}".format(i, tokenizer.decode(sample_output)),
                          flush=True)
                model.train()

            b_input_ids = batch[0].to(device)
            b_labels = batch[0].to(device)
            b_input_mask = batch[1].to(device)

            model.zero_grad()

            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)

            loss = calculate_loss(outputs[1],
                                  b_labels,
                                  b_input_mask,
                                  doc_start_ind,
                                  None,
                                  None,
                                  is_fine=False)
            # loss = outputs[0]
            total_train_loss += loss.item()

            loss.backward()
            optimizer.step()
            scheduler.step()

        for step, batch in enumerate(fine_train_dataloader):
            # batch contains -> fine_input_ids mini batch, fine_attention_masks mini batch
            if step % sample_every == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(
                    step, len(fine_train_dataloader), elapsed),
                      flush=True)
                model.eval()
                lbl = random.choice(child_labels)
                temp_list = ["<|labelpad|>"] * pad_token_dict[lbl]
                if len(temp_list) > 0:
                    label_str = " ".join(
                        lbl.split("_")) + " " + " ".join(temp_list)
                else:
                    label_str = " ".join(lbl.split("_"))
                text = tokenizer.bos_token + " " + label_str + " <|labelsep|> "
                sample_outputs = model.generate(input_ids=tokenizer.encode(
                    text, return_tensors='pt').to(device),
                                                do_sample=True,
                                                top_k=50,
                                                max_length=200,
                                                top_p=0.95,
                                                num_return_sequences=1)
                for i, sample_output in enumerate(sample_outputs):
                    print("{}: {}".format(i, tokenizer.decode(sample_output)),
                          flush=True)
                model.train()

            b_fine_input_ids_minibatch = batch[0].to(device)
            b_fine_input_mask_minibatch = batch[1].to(device)

            b_size = b_fine_input_ids_minibatch.shape[0]
            assert b_size == 1
            mini_batch_size = b_fine_input_ids_minibatch.shape[1]

            model.zero_grad()

            batch_other_log_probs = []
            prev_mask = None

            for b_ind in range(b_size):
                for mini_batch_ind in range(mini_batch_size):
                    b_fine_input_ids = b_fine_input_ids_minibatch[
                        b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                    b_fine_labels = b_fine_input_ids_minibatch[
                        b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                    b_fine_input_mask = b_fine_input_mask_minibatch[
                        b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                    outputs = model(b_fine_input_ids,
                                    token_type_ids=None,
                                    attention_mask=b_fine_input_mask,
                                    labels=b_fine_labels)
                    log_probs = torch.log_softmax(outputs[1], dim=-1)
                    doc_prob = compute_doc_prob(log_probs, b_fine_input_mask,
                                                b_fine_labels,
                                                doc_start_ind).unsqueeze(0)
                    if mini_batch_ind == 0:
                        batch_fine_log_probs = doc_prob
                        orig_output = outputs
                        orig_labels = b_fine_labels
                        orig_mask = b_fine_input_mask
                    else:
                        batch_other_log_probs.append(doc_prob)
                    if prev_mask is not None:
                        assert torch.all(b_fine_input_mask.eq(prev_mask))
                    prev_mask = b_fine_input_mask

            loss = calculate_loss(orig_output[1],
                                  orig_labels,
                                  orig_mask,
                                  doc_start_ind,
                                  batch_fine_log_probs,
                                  batch_other_log_probs,
                                  is_fine=True)
            # loss = criterion(batch_fine_probs.log(), batch_coarse_probs.detach()).sum(dim=-1).mean(dim=-1).mean(dim=-1)
            total_train_loss += loss.item()
            print("Loss:", loss.item(), flush=True)

            loss.backward()
            optimizer.step()
            scheduler.step()

        # **********************************

        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / (len(coarse_train_dataloader) +
                                             len(fine_train_dataloader))

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("", flush=True)
        print("  Average training loss: {0:.2f}".format(avg_train_loss),
              flush=True)
        print("  Training epcoh took: {:}".format(training_time), flush=True)

        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.

        print("", flush=True)
        print("Running Validation...", flush=True)

        t0 = time.time()

        model.eval()

        total_eval_loss = 0
        nb_eval_steps = 0

        # Evaluate data for one epoch
        for batch in coarse_validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_labels = batch[0].to(device)
            b_input_mask = batch[1].to(device)

            with torch.no_grad():
                outputs = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask,
                                labels=b_labels)

            # Accumulate the validation loss.
            loss = calculate_loss(outputs[1],
                                  b_labels,
                                  b_input_mask,
                                  doc_start_ind,
                                  None,
                                  None,
                                  is_fine=False)
            # loss = outputs[0]
            total_eval_loss += loss.item()

        for batch in fine_validation_dataloader:
            # batch contains -> fine_input_ids mini batch, fine_attention_masks mini batch
            b_fine_input_ids_minibatch = batch[0].to(device)
            b_fine_input_mask_minibatch = batch[1].to(device)

            b_size = b_fine_input_ids_minibatch.shape[0]
            assert b_size == 1
            mini_batch_size = b_fine_input_ids_minibatch.shape[1]

            with torch.no_grad():
                batch_other_log_probs = []
                prev_mask = None

                for b_ind in range(b_size):
                    for mini_batch_ind in range(mini_batch_size):
                        b_fine_input_ids = b_fine_input_ids_minibatch[
                            b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                        b_fine_labels = b_fine_input_ids_minibatch[
                            b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                        b_fine_input_mask = b_fine_input_mask_minibatch[
                            b_ind, mini_batch_ind, :].unsqueeze(0).to(device)
                        outputs = model(b_fine_input_ids,
                                        token_type_ids=None,
                                        attention_mask=b_fine_input_mask,
                                        labels=b_fine_labels)
                        log_probs = torch.log_softmax(outputs[1], dim=-1)
                        doc_prob = compute_doc_prob(log_probs,
                                                    b_fine_input_mask,
                                                    b_fine_labels,
                                                    doc_start_ind).unsqueeze(0)
                        if mini_batch_ind == 0:
                            batch_fine_log_probs = doc_prob
                            orig_output = outputs
                            orig_labels = b_fine_labels
                            orig_mask = b_fine_input_mask
                        else:
                            batch_other_log_probs.append(doc_prob)
                        if prev_mask is not None:
                            assert torch.all(b_fine_input_mask.eq(prev_mask))
                        prev_mask = b_fine_input_mask

            loss = calculate_loss(orig_output[1],
                                  orig_labels,
                                  orig_mask,
                                  doc_start_ind,
                                  batch_fine_log_probs,
                                  batch_other_log_probs,
                                  is_fine=True)
            total_eval_loss += loss.item()

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / (len(coarse_validation_dataloader) +
                                          len(fine_validation_dataloader))

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)

        print("  Validation Loss: {0:.2f}".format(avg_val_loss), flush=True)
        print("  Validation took: {:}".format(validation_time), flush=True)

        # Record all statistics from this epoch.
        training_stats.append({
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        })

    print("", flush=True)
    print("Training complete!", flush=True)

    print("Total training took {:} (h:mm:ss)".format(
        format_time(time.time() - total_t0)),
          flush=True)
    return model
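
# A minimal sketch of the MarginRankingLoss term in calculate_hinge_loss above, on
# synthetic scalars (not part of the original listing): with target y = 1 the loss is
# max(0, margin - (x1 - x2)), so the document log-probability under the intended fine
# label (x1) is pushed to exceed each competing label's log-probability (x2) by at least
# margin = 1.609, which is roughly ln(5), i.e. about a 5x likelihood ratio.
_rank_loss = MarginRankingLoss(margin=1.609)
_true_lp = torch.tensor([-40.0])    # log P(doc | intended fine label)
_other_lp = torch.tensor([-41.0])   # log P(doc | a competing fine label)
_y = torch.ones(1)                  # y = 1: the first argument should rank higher
print(_rank_loss(_true_lp, _other_lp, _y))  # max(0, 1.609 - 1.0) = 0.609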