    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length],
                               self.vocab_size)
        input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

        input_lengths = None
        if self.use_input_lengths:
            input_lengths = (ids_tensor([self.batch_size], vocab_size=2) +
                             self.seq_length - 2
                             )  # small variation of seq_length

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length],
                                        self.n_langs)

        sequence_labels = None
        token_labels = None
        is_impossible_labels = None
        choice_labels = None  # must be initialized: it is returned unconditionally below
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size],
                                         self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length],
                                      self.num_labels)
            is_impossible_labels = ids_tensor([self.batch_size], 2).float()
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        # Reuse get_config() instead of duplicating the FlaubertConfig kwargs;
        # return_dict defaults to True in recent transformers versions.
        config = self.get_config()

        return (
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            choice_labels,
            input_mask,
        )
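
    # A companion helper, given as a minimal sketch in the style of the
    # transformers test suite (the dict keys mirror FlaubertModel.forward's
    # `input_ids`, `token_type_ids`, `lengths`, and `attention_mask` arguments;
    # this method is an assumed convention, not part of the original tester):
    def prepare_config_and_inputs_for_common(self):
        (
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            choice_labels,
            input_mask,
        ) = self.prepare_config_and_inputs()
        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "lengths": input_lengths,
            "attention_mask": input_mask,
        }
        return config, inputs_dict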
    def get_config(self):
        return FlaubertConfig(
            vocab_size=self.vocab_size,
            n_special=self.n_special,
            emb_dim=self.hidden_size,
            n_layers=self.num_hidden_layers,
            n_heads=self.num_attention_heads,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            gelu_activation=self.gelu_activation,
            sinusoidal_embeddings=self.sinusoidal_embeddings,
            asm=self.asm,
            causal=self.causal,
            n_langs=self.n_langs,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            summary_type=self.summary_type,
            use_proj=self.use_proj,
        )
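
    # Usage sketch for the two builders above (assumes `FlaubertModel` is
    # imported and `tester` is an instance of this class with the attributes
    # referenced above):
    #
    #     config, input_ids, *_, input_mask = tester.prepare_config_and_inputs()
    #     model = FlaubertModel(config)
    #     model.eval()
    #     outputs = model(input_ids, attention_mask=input_mask)
    #     assert outputs.last_hidden_state.shape == (
    #         tester.batch_size, tester.seq_length, config.emb_dim)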
    @classmethod
    def create(cls,
               model_type='camem',
               model_name="camembert-base",
               embedding_size=768,
               hidden_dim=512,
               rnn_layers=1,
               lstm_dropout=0.5,
               device="cuda",
               mode="weighted",
               key_dim=64,
               val_dim=64,
               num_heads=3,
               attn_dropout=0.3,
               self_attention=False,
               is_require_grad=False):
        configuration = {
            "model_type": model_type,
            "model_name": model_name,
            "device": device,
            "mode": mode,
            "self_attention": self_attention,
            # the encoder is frozen exactly when its weights do not require gradients
            "is_freeze": not is_require_grad,
        }

        if 'camem' in model_type:
            config_bert = CamembertConfig.from_pretrained(
                model_name, output_hidden_states=True)
            model = CamembertModel.from_pretrained(model_name,
                                                   config=config_bert)
            model.to(device)
        elif 'flaubert' in model_type:
            config_bert = FlaubertConfig.from_pretrained(
                model_name, output_hidden_states=True)
            model = FlaubertModel.from_pretrained(model_name,
                                                  config=config_bert)
            model.to(device)
        elif 'XLMRoberta' in model_type:
            config_bert = XLMRobertaConfig.from_pretrained(
                model_name, output_hidden_states=True)
            model = XLMRobertaModel.from_pretrained(model_name,
                                                    config=config_bert)
            model.to(device)
        elif 'M-Bert' in model_type:
            config_bert = BertConfig.from_pretrained(model_name,
                                                     output_hidden_states=True)
            model = BertModel.from_pretrained(model_name, config=config_bert)
            model.to(device)
        else:
            # fail early instead of hitting a NameError on `model` below
            raise ValueError(f"Unsupported model_type: {model_type!r}")

        lstm = BiLSTM.create(embedding_size=embedding_size,
                             hidden_dim=hidden_dim,
                             rnn_layers=rnn_layers,
                             dropout=lstm_dropout)

        attn = MultiHeadAttention(key_dim, val_dim, hidden_dim, num_heads,
                                  attn_dropout)
        model.train()
        instance = cls(model=model, config=configuration, lstm=lstm, attn=attn)
        if configuration["is_freeze"]:
            instance.freeze()

        return instance
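
    # Usage sketch (the enclosing class name `BertBiLSTMAttn` is hypothetical;
    # only the `create` signature above comes from this code):
    #
    #     model = BertBiLSTMAttn.create(model_type='flaubert',
    #                                   model_name='flaubert/flaubert_base_cased',
    #                                   device='cpu',
    #                                   is_require_grad=False)  # encoder frozen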