# NOTE: imports assume the pytorch_pretrained_bert package (the
# pre-`transformers` HuggingFace library). `BertConfig.dummy_config`
# is not part of the stock 0.6.x release, so it is presumably defined
# in this project's copy of the library.
from pytorch_pretrained_bert import BertConfig, BertForSequenceClassification


def get_model(model_type,
              tokenizer,
              num_labels,
              toy_classifier=False,
              dry_run=False,
              n_heads=1,
              state_dict=None,
              cache_dir=None):
    """Build a BERT sequence classifier.

    `tokenizer` and `num_labels` were free (presumably module-level)
    names in the original snippet; they are taken as explicit
    parameters here so the function is self-contained.
    """

    if dry_run:
        # Dry run: build a randomly initialized model from a dummy
        # config instead of downloading pretrained weights.
        model = BertForSequenceClassification(
            BertConfig.dummy_config(len(tokenizer.vocab)),
            num_labels=num_labels,
        )
    else:
        model = BertForSequenceClassification.from_pretrained(
            model_type,
            cache_dir=cache_dir,
            num_labels=num_labels,
            # The toy classifier loads `state_dict` itself below, so
            # skip it here to avoid loading it twice.
            state_dict=None if toy_classifier else state_dict,
        )
    if toy_classifier:
        # Single-layer BERT with a configurable number of attention
        # heads; every other hyperparameter matches bert-base.
        config = BertConfig(len(tokenizer.vocab),
                            hidden_size=768,
                            num_hidden_layers=1,
                            num_attention_heads=n_heads,
                            intermediate_size=3072,
                            hidden_act="gelu",
                            hidden_dropout_prob=0.1,
                            attention_probs_dropout_prob=0.1,
                            max_position_embeddings=512,
                            type_vocab_size=2,
                            initializer_range=0.02)
        toy_model = BertForSequenceClassification(config,
                                                  num_labels=num_labels)
        # Reuse the base model's (pretrained or dummy) embeddings.
        toy_model.bert.embeddings.load_state_dict(
            model.bert.embeddings.state_dict())
        if state_dict is not None:
            # Unwrap a DataParallel-style wrapper if one is present.
            model_to_load = getattr(toy_model, "module", toy_model)
            model_to_load.load_state_dict(state_dict)
        model = toy_model

    return model
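

# Usage sketch, not from the original source: builds a one-head toy
# classifier on top of pretrained bert-base-uncased embeddings.
# `BertTokenizer` comes from the same pytorch_pretrained_bert package;
# the label count (2) is a placeholder.
if __name__ == "__main__":
    from pytorch_pretrained_bert import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = get_model("bert-base-uncased",
                      tokenizer,
                      num_labels=2,
                      toy_classifier=True,
                      n_heads=1)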