Example #1
def main():
    args = _get_parser().parse_args()

    args.device_ids = list(map(int, args.device_ids.split(',')))

    set_seed(args)
    sanity_checks(args)
    init_gpu_params(args)

    tokenizer = Tokenizer(
        os.path.join(args.bert_model, "senti_vocab.txt"),
        os.path.join(args.bert_model, "RoBERTa_Sentiment_kor"))

    train_dataset = NSMCDataSet(data_split="train",
                                tokenizer=tokenizer,
                                max_seq_length=args.max_seq_length,
                                pad_to_max=args.pad_to_max)

    train_sampler = (DistributedSampler(train_dataset)
                     if args.multi_gpu else RandomSampler(train_dataset))
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.per_gpu_train_batch_size,
                                  collate_fn=train_dataset.collate_fn)

    model = RobertaForSequenceClassification(
        classifier_dropout=args.classifier_dropout,
        bert_model_dir=args.bert_model,
        pre_trained_model=args.pretrained_bert_model)

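    # Exclude biases and LayerNorm weights from weight decay, as is standard
    # for BERT-style fine-tuning.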
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

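    # One optimizer update per gradient_accumulation_steps batches, over all
    # epochs; warmup covers a fixed proportion of those updates.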
    t_total = (len(train_dataloader) // args.gradient_accumulation_steps *
               args.num_train_epochs)
    warmup_steps = math.ceil(t_total * args.warmup_proportion)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)

    model.zero_grad()
    model.cuda()

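    # In multi-GPU mode, wrap the model in DistributedDataParallel pinned to
    # this process's device.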
    if args.multi_gpu:
        model = DistributedDataParallel(
            model,
            device_ids=[args.device_ids[args.local_rank]],
            output_device=args.device_ids[args.local_rank])

    if args.is_master:
        logger.info(json.dumps(vars(args), indent=4))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Num Epochs = %d", args.num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d",
                    args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
            (torch.distributed.get_world_size() if args.multi_gpu else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

    global_steps = 0
    for epoch in range(args.num_train_epochs):
        if args.multi_gpu:
            train_sampler.set_epoch(epoch)

        loss_bce = nn.BCEWithLogitsLoss()

        iter_loss = 0

        model.train()

        pbar = tqdm(train_dataloader, desc="Iter", disable=not args.is_master)
        for step, batch in enumerate(pbar):
            input_ids, attention_mask, labels = batch

            inputs = {
                "input_ids": torch.tensor(input_ids, dtype=torch.long).cuda(),
                "attention_mask": torch.tensor(attention_mask,
                                               dtype=torch.long).cuda()
            }
            logits = model(**inputs)

            labels = torch.tensor(labels, dtype=torch.float).cuda()

            loss = loss_bce(input=logits.view(-1), target=labels.view(-1))
            if args.gradient_accumulation_steps > 1:
                loss /= args.gradient_accumulation_steps
            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

                global_steps += 1
                if global_steps % args.save_checkpoints_steps == 0 and args.is_master:
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    save_path = os.path.join(args.save_checkpoints_dir,
                                             f"step_{global_steps}.ckpt")
                    torch.save(model_to_save.state_dict(), save_path)

            iter_loss += loss.item()
            pbar.set_postfix({
                "epoch": epoch,
                "global_steps": global_steps,
                "learning_rate": f"{scheduler.get_last_lr()[0]:.10f}",
                "avg_iter_loss": f"{iter_loss / (step + 1) * args.gradient_accumulation_steps:.5f}",
                "last_loss": f"{loss.item() * args.gradient_accumulation_steps:.5f}",
            })
        pbar.close()

        if args.is_master:
            model_to_save = model.module if hasattr(model, 'module') else model
            save_path = os.path.join(args.save_checkpoints_dir,
                                     f"epoch_{epoch+1}.ckpt")
            torch.save(model_to_save.state_dict(), save_path)
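
A minimal sketch of restoring one of the checkpoints saved above for evaluation. It reuses the RobertaForSequenceClassification class from the script; the dropout value and every path are placeholder assumptions, not taken from the source.

import torch

# Rebuild the model exactly as in training, then load a saved state_dict.
model = RobertaForSequenceClassification(
    classifier_dropout=0.1,                       # placeholder value
    bert_model_dir="path/to/bert_model",          # placeholder path
    pre_trained_model="path/to/pretrained.ckpt")  # placeholder path
state_dict = torch.load("checkpoints/epoch_3.ckpt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
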
Example #2
def main():
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument("--force",
                        action="store_true",
                        help="Overwrite dump_path if it already exists.")

    parser.add_argument(
        "--dump_path",
        type=str,
        required=True,
        help="The output directory (log, checkpoints, parameters, etc.)")
    parser.add_argument(
        "--data_file",
        type=str,
        required=True,
        help=
        "The binarized file (tokenized + tokens_to_ids), grouped by sequence.",
    )

    parser.add_argument(
        "--student_type",
        type=str,
        choices=["distilbert", "roberta", "gpt2"],
        required=True,
        help="The student type (DistilBERT, RoBERTa).",
    )
    parser.add_argument("--student_config",
                        type=str,
                        required=True,
                        help="Path to the student configuration.")
    parser.add_argument("--student_pretrained_weights",
                        default=None,
                        type=str,
                        help="Load student initialization checkpoint.")

    parser.add_argument("--teacher_type",
                        choices=["bert", "roberta", "gpt2"],
                        required=True,
                        help="Teacher type (BERT, RoBERTa).")
    parser.add_argument("--teacher_name",
                        type=str,
                        required=True,
                        help="The teacher model.")

    parser.add_argument("--temperature",
                        default=2.0,
                        type=float,
                        help="Temperature for the softmax temperature.")
    parser.add_argument(
        "--alpha_ce",
        default=0.5,
        type=float,
        help="Linear weight for the distillation loss. Must be >=0.")
    parser.add_argument(
        "--alpha_mlm",
        default=0.0,
        type=float,
        help=
        "Linear weight for the MLM loss. Must be >=0. Should be used in conjunction with the `mlm` flag.",
    )
    parser.add_argument("--alpha_clm",
                        default=0.5,
                        type=float,
                        help="Linear weight for the CLM loss. Must be >=0.")
    parser.add_argument("--alpha_mse",
                        default=0.0,
                        type=float,
                        help="Linear weight of the MSE loss. Must be >=0.")
    parser.add_argument(
        "--alpha_cos",
        default=0.0,
        type=float,
        help="Linear weight of the cosine embedding loss. Must be >=0.")

    parser.add_argument(
        "--mlm",
        action="store_true",
        help=
        "The LM step: MLM or CLM. If `mlm` is set, MLM is used instead of CLM.")
    parser.add_argument(
        "--mlm_mask_prop",
        default=0.15,
        type=float,
        help="Proportion of tokens for which we need to make a prediction.",
    )
    parser.add_argument("--word_mask",
                        default=0.8,
                        type=float,
                        help="Proportion of tokens to mask out.")
    parser.add_argument("--word_keep",
                        default=0.1,
                        type=float,
                        help="Proportion of tokens to keep.")
    parser.add_argument("--word_rand",
                        default=0.1,
                        type=float,
                        help="Proportion of tokens to randomly replace.")
    parser.add_argument(
        "--mlm_smoothing",
        default=0.7,
        type=float,
        help=
        "Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
    )
    parser.add_argument("--token_counts",
                        type=str,
                        help="The token counts in the data_file for MLM.")

    parser.add_argument(
        "--restrict_ce_to_mask",
        action="store_true",
        help=
        "If set, compute the distillation loss only on the [MLM] prediction distribution.",
    )
    parser.add_argument(
        "--freeze_pos_embs",
        action="store_true",
        help=
        "Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
    )
    parser.add_argument(
        "--freeze_token_type_embds",
        action="store_true",
        help=
        "Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
    )

    parser.add_argument("--n_epoch",
                        type=int,
                        default=3,
                        help="Number of pass on the whole dataset.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=5,
                        help="Batch size (for each process).")
    parser.add_argument(
        "--group_by_size",
        action="store_false",
        help=
        "Group sequences of similar length into the same batch (default: true); pass this flag to disable.",
    )

    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=50,
        help="Gradient accumulation for larger training batches.",
    )
    parser.add_argument("--warmup_prop",
                        default=0.05,
                        type=float,
                        help="Linear warmup proportion.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--learning_rate",
                        default=5e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--adam_epsilon",
                        default=1e-6,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=5.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--initializer_range",
                        default=0.02,
                        type=float,
                        help="Random initialization range.")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--n_gpu",
                        type=int,
                        default=1,
                        help="Number of GPUs in the node.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="Distributed training - Local rank")
    parser.add_argument("--seed", type=int, default=56, help="Random seed")

    parser.add_argument("--log_interval",
                        type=int,
                        default=500,
                        help="Tensorboard logging interval.")
    parser.add_argument("--checkpoint_interval",
                        type=int,
                        default=4000,
                        help="Checkpoint interval.")
    args = parser.parse_args()
    sanity_checks(args)

    # ARGS #
    init_gpu_params(args)
    set_seed(args)
    if args.is_master:
        if os.path.exists(args.dump_path):
            if not args.force:
                raise ValueError(
                    f"Serialization dir {args.dump_path} already exists, but you have not specified whether to overwrite it. "
                    "Use `--force` if you want to overwrite it.")
            else:
                shutil.rmtree(args.dump_path)

        if not os.path.exists(args.dump_path):
            os.makedirs(args.dump_path)
        logger.info(
            f"Experiment will be dumped and logged in {args.dump_path}")

        # SAVE PARAMS #
        logger.info(f"Param: {args}")
        with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
            json.dump(vars(args), f, indent=4)
        git_log(args.dump_path)

    student_config_class, student_model_class, _ = MODEL_CLASSES[
        args.student_type]
    teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[
        args.teacher_type]

    # TOKENIZER #
    tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
    special_tok_ids = {}
    for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
        idx = tokenizer.all_special_tokens.index(tok_symbol)
        special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
    logger.info(f"Special tokens {special_tok_ids}")
    args.special_tok_ids = special_tok_ids
    args.max_model_input_size = tokenizer.max_model_input_sizes[
        args.teacher_name]

    # DATA LOADER #
    logger.info(f"Loading data from {args.data_file}")
    with open(args.data_file, "rb") as fp:
        data = pickle.load(fp)

    if args.mlm:
        logger.info(
            f"Loading token counts from {args.token_counts} (already pre-computed)"
        )
        with open(args.token_counts, "rb") as fp:
            counts = pickle.load(fp)

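        # Sampling weight for masking proportional to count**(-mlm_smoothing),
        # so rarer tokens are selected for prediction more often.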
        token_probs = np.maximum(counts, 1)**-args.mlm_smoothing
        for idx in special_tok_ids.values():
            token_probs[idx] = 0.0  # do not predict special tokens
        token_probs = torch.from_numpy(token_probs)
    else:
        token_probs = None

    train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
    logger.info(f"Data loader created.")

    # STUDENT #
    logger.info(f"Loading student config from {args.student_config}")
    stu_architecture_config = student_config_class.from_pretrained(
        args.student_config)
    stu_architecture_config.output_hidden_states = True

    if args.student_pretrained_weights is not None:
        logger.info(
            f"Loading pretrained weights from {args.student_pretrained_weights}"
        )
        student = student_model_class.from_pretrained(
            args.student_pretrained_weights, config=stu_architecture_config)
    else:
        student = student_model_class(stu_architecture_config)

    if args.n_gpu > 0:
        student.to(f"cuda:{args.local_rank}")
    logger.info(f"Student loaded.")

    # TEACHER #
    teacher = teacher_model_class.from_pretrained(args.teacher_name,
                                                  output_hidden_states=True)
    if args.n_gpu > 0:
        teacher.to(f"cuda:{args.local_rank}")
    logger.info(f"Teacher loaded from {args.teacher_name}.")

    # FREEZING #
    if args.freeze_pos_embs:
        freeze_pos_embeddings(student, args)
    if args.freeze_token_type_embds:
        freeze_token_type_embeddings(student, args)

    # SANITY CHECKS #
    assert student.config.vocab_size == teacher.config.vocab_size
    assert student.config.hidden_size == teacher.config.hidden_size
    assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
    if args.mlm:
        assert token_probs.size(0) == stu_architecture_config.vocab_size

    # DISTILLER #
    torch.cuda.empty_cache()
    distiller = Distiller(params=args,
                          dataset=train_lm_seq_dataset,
                          token_probs=token_probs,
                          student=student,
                          teacher=teacher)
    distiller.train()
    logger.info("Let's go get some drinks.")
Example #3
def main():
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument(
        "--dump_path",
        type=str,
        required=True,
        help="The output directory (log, checkpoints, parameters, etc.)")
    parser.add_argument(
        "--data_file",
        type=str,
        required=True,
        help=
        "The binarized file (tokenized + tokens_to_ids), grouped by sequence."
    )
    parser.add_argument("--token_counts",
                        type=str,
                        required=True,
                        help="The token counts in the data_file for MLM.")
    parser.add_argument("--force",
                        action='store_true',
                        help="Overwrite dump_path if it already exists.")

    parser.add_argument("--vocab_size",
                        default=30522,
                        type=int,
                        help="The vocabulary size.")
    parser.add_argument(
        "--max_position_embeddings",
        default=512,
        type=int,
        help="Maximum sequence length we can model (including [CLS] and [SEP])."
    )
    parser.add_argument(
        "--sinusoidal_pos_embds",
        action='store_false',
        help=
        "Use fixed sinusoidal position embeddings (default: true); pass this flag to disable."
    )
    parser.add_argument("--n_layers",
                        default=6,
                        type=int,
                        help="Number of Transformer blocks.")
    parser.add_argument("--n_heads",
                        default=12,
                        type=int,
                        help="Number of heads in the self-attention module.")
    parser.add_argument(
        "--dim",
        default=768,
        type=int,
        help="Dimension through the network. Must be divisible by n_heads")
    parser.add_argument("--hidden_dim",
                        default=3072,
                        type=int,
                        help="Intermediate dimension in the FFN.")
    parser.add_argument("--dropout", default=0.1, type=float, help="Dropout.")
    parser.add_argument("--attention_dropout",
                        default=0.1,
                        type=float,
                        help="Dropout in self-attention.")
    parser.add_argument("--activation",
                        default='gelu',
                        type=str,
                        help="Activation to use in self-attention")
    parser.add_argument(
        "--tie_weights_",
        action='store_false',
        help=
        "Tie the embedding matrix with the vocabulary projection matrix (default: true); pass this flag to disable."
    )

    parser.add_argument("--from_pretrained_weights",
                        default=None,
                        type=str,
                        help="Load student initialization checkpoint.")
    parser.add_argument(
        "--from_pretrained_config",
        default=None,
        type=str,
        help="Load student initialization architecture config.")
    parser.add_argument("--teacher_type",
                        default="bert",
                        choices=["bert", "roberta"],
                        help="Teacher type (BERT, RoBERTa).")
    parser.add_argument("--teacher_name",
                        default="bert-base-uncased",
                        type=str,
                        help="The teacher model.")

    parser.add_argument("--temperature",
                        default=2.,
                        type=float,
                        help="Temperature for the softmax temperature.")
    parser.add_argument(
        "--alpha_ce",
        default=0.5,
        type=float,
        help="Linear weight for the distillation loss. Must be >=0.")
    parser.add_argument("--alpha_mlm",
                        default=0.5,
                        type=float,
                        help="Linear weight for the MLM loss. Must be >=0.")
    parser.add_argument("--alpha_mse",
                        default=0.0,
                        type=float,
                        help="Linear weight of the MSE loss. Must be >=0.")
    parser.add_argument(
        "--alpha_cos",
        default=0.0,
        type=float,
        help="Linear weight of the cosine embedding loss. Must be >=0.")
    parser.add_argument(
        "--mlm_mask_prop",
        default=0.15,
        type=float,
        help="Proportion of tokens for which we need to make a prediction.")
    parser.add_argument("--word_mask",
                        default=0.8,
                        type=float,
                        help="Proportion of tokens to mask out.")
    parser.add_argument("--word_keep",
                        default=0.1,
                        type=float,
                        help="Proportion of tokens to keep.")
    parser.add_argument("--word_rand",
                        default=0.1,
                        type=float,
                        help="Proportion of tokens to randomly replace.")
    parser.add_argument(
        "--mlm_smoothing",
        default=0.7,
        type=float,
        help=
        "Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec)."
    )
    parser.add_argument(
        "--restrict_ce_to_mask",
        action='store_true',
        help=
        "If set, compute the distillation loss only on the [MLM] prediction distribution."
    )

    parser.add_argument("--n_epoch",
                        type=int,
                        default=3,
                        help="Number of pass on the whole dataset.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=5,
                        help="Batch size (for each process).")
    parser.add_argument(
        "--tokens_per_batch",
        type=int,
        default=-1,
        help=
        "If specified, modify the batches so that they have approximately this number of tokens."
    )
    parser.add_argument(
        "--shuffle",
        action='store_false',
        help="If true, shuffle the sequence order. Default is true.")
    parser.add_argument(
        "--group_by_size",
        action='store_false',
        help=
        "Group sequences of similar length into the same batch (default: true); pass this flag to disable."
    )

    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=50,
        help="Gradient accumulation for larger training batches.")
    parser.add_argument("--warmup_prop",
                        default=0.05,
                        type=float,
                        help="Linear warmup proportion.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--learning_rate",
                        default=5e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--adam_epsilon",
                        default=1e-6,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=5.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--initializer_range",
                        default=0.02,
                        type=float,
                        help="Random initialization range.")

    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--n_gpu",
                        type=int,
                        default=1,
                        help="Number of GPUs in the node.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="Distributed training - Local rank")
    parser.add_argument("--seed", type=int, default=56, help="Random seed")

    parser.add_argument("--log_interval",
                        type=int,
                        default=500,
                        help="Tensorboard logging interval.")
    parser.add_argument("--checkpoint_interval",
                        type=int,
                        default=4000,
                        help="Checkpoint interval.")
    args = parser.parse_args()

    ## ARGS ##
    init_gpu_params(args)
    set_seed(args)
    if args.is_master:
        if os.path.exists(args.dump_path):
            if not args.force:
                raise ValueError(
                    f'Serialization dir {args.dump_path} already exists, but you have not specified whether to overwrite it. '
                    'Use `--force` if you want to overwrite it.')
            else:
                shutil.rmtree(args.dump_path)

        if not os.path.exists(args.dump_path):
            os.makedirs(args.dump_path)
        logger.info(
            f'Experiment will be dumped and logged in {args.dump_path}')

        ### SAVE PARAMS ###
        logger.info(f'Param: {args}')
        with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
            json.dump(vars(args), f, indent=4)
        git_log(args.dump_path)
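
    # Pretrained weights and config must be provided together, or neither at all.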
    assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
           (args.from_pretrained_weights is not None and args.from_pretrained_config is not None)

    ### TOKENIZER ###
    if args.teacher_type == 'bert':
        tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
    elif args.teacher_type == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
    special_tok_ids = {}
    for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
        idx = tokenizer.all_special_tokens.index(tok_symbol)
        special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
    logger.info(f'Special tokens {special_tok_ids}')
    args.special_tok_ids = special_tok_ids

    ## DATA LOADER ##
    logger.info(f'Loading data from {args.data_file}')
    with open(args.data_file, 'rb') as fp:
        data = pickle.load(fp)

    assert os.path.isfile(args.token_counts)
    logger.info(
        f'Loading token counts from {args.token_counts} (already pre-computed)'
    )
    with open(args.token_counts, 'rb') as fp:
        counts = pickle.load(fp)
        assert len(counts) == args.vocab_size
    token_probs = np.maximum(counts, 1)**-args.mlm_smoothing
    for idx in special_tok_ids.values():
        token_probs[idx] = 0.  # do not predict special tokens
    token_probs = torch.from_numpy(token_probs)

    train_dataloader = Dataset(params=args, data=data)
    logger.info(f'Data loader created.')

    ## STUDENT ##
    if args.from_pretrained_weights is not None:
        assert os.path.isfile(args.from_pretrained_weights)
        assert os.path.isfile(args.from_pretrained_config)
        logger.info(
            f'Loading pretrained weights from {args.from_pretrained_weights}')
        logger.info(
            f'Loading pretrained config from {args.from_pretrained_config}')
        stu_architecture_config = DistilBertConfig.from_json_file(
            args.from_pretrained_config)
        stu_architecture_config.output_hidden_states = True
        student = DistilBertForMaskedLM.from_pretrained(
            args.from_pretrained_weights, config=stu_architecture_config)
    else:
        args.vocab_size_or_config_json_file = args.vocab_size
        stu_architecture_config = DistilBertConfig(**vars(args),
                                                   output_hidden_states=True)
        student = DistilBertForMaskedLM(stu_architecture_config)

    if args.n_gpu > 0:
        student.to(f'cuda:{args.local_rank}')
    logger.info(f'Student loaded.')

    ## TEACHER ##
    if args.teacher_type == 'bert':
        teacher = BertForMaskedLM.from_pretrained(args.teacher_name,
                                                  output_hidden_states=True)
    elif args.teacher_type == 'roberta':
        teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name,
                                                     output_hidden_states=True)
    if args.n_gpu > 0:
        teacher.to(f'cuda:{args.local_rank}')
    logger.info(f'Teacher loaded from {args.teacher_name}.')

    ## DISTILLER ##
    torch.cuda.empty_cache()
    distiller = Distiller(params=args,
                          dataloader=train_dataloader,
                          token_probs=token_probs,
                          student=student,
                          teacher=teacher)
    distiller.train()
    logger.info("Let's go get some drinks.")