    @classmethod
    def from_huggingface_model(cls, model: PreTrainedModel,
                               ffn_activation: str, ffn_dropout: float,
                               attention: Attention):
        config = model.config
        encoder = cls(n_layers=config.n_layers,
                      n_heads=config.n_heads,
                      dim=config.dim,
                      hidden_dim=config.hidden_dim,
                      ffn_activation=ffn_activation,
                      ffn_dropout=ffn_dropout,
                      attention=attention)
        # After creating the encoder, we copy weights over from the transformer.  This currently
        # requires that the internal structure of the text side of this encoder *exactly matches*
        # the internal structure of whatever transformer you're using.
        encoder_parameters = dict(encoder.named_parameters())
        for name, parameter in model.named_parameters():
            if name.startswith("transformer."):
                name = name.replace("LayerNorm", "layer_norm")
                if name not in encoder_parameters:
                    raise ValueError(
                        f"Couldn't find a matching parameter for {name}. Is this transformer "
                        "compatible with the joint encoder you're using?")
                encoder_parameters[name].data.copy_(parameter.data)

        return encoder
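
# Example usage (sketch): the config attributes read above (n_layers, n_heads,
# dim, hidden_dim) follow DistilBERT's naming, and the "transformer." prefix
# matches DistilBertModel's parameter names, so a DistilBERT checkpoint is a
# natural fit. "TextEncoder" below stands in for whatever class this
# classmethod is defined on, and `attention` for an Attention module its
# constructor accepts:
#
#     from transformers import AutoModel
#     hf_model = AutoModel.from_pretrained("distilbert-base-uncased")
#     encoder = TextEncoder.from_huggingface_model(hf_model,
#                                                  ffn_activation="relu",
#                                                  ffn_dropout=0.1,
#                                                  attention=attention)
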
def train(model: PreTrainedModel,
          train_dataloader: DataLoader,
          dev_dataloader: DataLoader,
          batch_size: int,
          gradient_accumulation_steps: int,
          device,
          num_train_epochs: int = 20,
          warmup_proportion: float = 0.1,
          learning_rate: float = 1e-5,
          patience: int = 5,
          output_dir: str = "/tmp/",
          model_file_name: str = "model.bin") -> str:
    """
    Trains a BERT model on a set of training data, tuning it on a set of development data.

    Args:
        model: the model that will be trained
        train_dataloader: a DataLoader with training data
        dev_dataloader: a DataLoader with development data (for early stopping)
        batch_size: the batch size for training
        gradient_accumulation_steps: the number of batches over which gradients are accumulated
                                     before each optimizer step
        device: the device where training will take place ("cpu" or "cuda")
        num_train_epochs: the maximum number of training epochs
        warmup_proportion: the proportion of training steps for which the learning rate will be warmed up
        learning_rate: the initial learning rate
        patience: the number of epochs after which training will stop if no improvement on the dev
                  set is observed
        output_dir: the directory where the model will be saved
        model_file_name: the filename of the model file

    Returns: the path to the trained model

    """
    def warmup_linear(x, warmup=0.002):
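        # x is the fraction of optimizer steps completed; ramp the learning-rate
        # multiplier linearly from 0 to 1 over the first `warmup` fraction of
        # training, then decay it linearly towards 0.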
        if x < warmup:
            return x / warmup
        return 1.0 - x

    output_model_file = os.path.join(output_dir, model_file_name)
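
    # Total number of optimizer updates over the whole run; only used to scale
    # the learning-rate schedule below.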

    num_train_steps = int(
        len(train_dataloader.dataset) / batch_size /
        gradient_accumulation_steps * num_train_epochs)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
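    # Apply weight decay to all parameters except biases and LayerNorm terms,
    # following the usual BERT fine-tuning recipe.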
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
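
    # correct_bias=False skips Adam's bias correction, matching the optimizer
    # used in the original BERT implementation.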

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=learning_rate,
                      correct_bias=False)

    global_step = 0
    loss_history = []
    best_epoch = 0
    for epoch in trange(int(num_train_epochs), desc="Epoch"):

        model.train()
        tr_loss = 0
        for step, batch in enumerate(
                tqdm(train_dataloader, desc="Training iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
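            # DistilBERT models do not take token_type_ids, so the forward call
            # differs slightly between the two model families.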
            if isinstance(model, (BertForSequenceClassification,
                                  BertForMultiLabelSequenceClassification)):
                outputs = model(input_ids,
                                attention_mask=input_mask,
                                token_type_ids=segment_ids,
                                labels=label_ids)
            elif isinstance(model, DistilBertForSequenceClassification):
                outputs = model(input_ids,
                                attention_mask=input_mask,
                                labels=label_ids)
            else:
                raise ValueError(
                    f"Unsupported model class: {type(model).__name__}")
            loss = outputs[0]

            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
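
            # Weights are only updated every `gradient_accumulation_steps` batches;
            # the learning rate is set manually from the warmup/decay schedule
            # defined by warmup_linear above.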

            if (step + 1) % gradient_accumulation_steps == 0:
                lr_this_step = learning_rate * warmup_linear(
                    global_step / num_train_steps, warmup_proportion)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

        dev_loss, _, _ = evaluate(model, dev_dataloader, device)

        print("Loss history:", loss_history)
        print("Dev loss:", dev_loss)
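
        # Save a checkpoint whenever the development loss improves on the best
        # value seen so far, and stop early once there has been no improvement
        # for `patience` consecutive epochs.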

        if len(loss_history) == 0 or dev_loss < min(loss_history):
            model_to_save = model.module if hasattr(model, 'module') else model
            torch.save(model_to_save.state_dict(), output_model_file)
            best_epoch = epoch

        if epoch - best_epoch >= patience:
            print("No improvement on the development set. Stopping training.")
            break

        loss_history.append(dev_loss)

    return output_model_file
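
# Example call (sketch), assuming the usual imports (torch, transformers, tqdm)
# at the top of the file. The dataloaders are expected to yield
# (input_ids, input_mask, segment_ids, label_ids) batches, matching the
# unpacking in the training loop; `make_dataloader` is a placeholder for
# whatever featurization pipeline produces them:
#
#     model = BertForSequenceClassification.from_pretrained(
#         "bert-base-uncased", num_labels=2).to("cuda")
#     best_model_path = train(model,
#                             train_dataloader=make_dataloader(train_examples),
#                             dev_dataloader=make_dataloader(dev_examples),
#                             batch_size=16,
#                             gradient_accumulation_steps=1,
#                             device="cuda")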