Example #1
def main():
    parser = argparse.ArgumentParser(description='pruning_one-step.py')
    parser.add_argument('-model_path',
                        default='../KD/models/bert_ft',
                        type=str,
                        help="distill type")
    parser.add_argument('-output_dir',
                        default='models/prun_bert',
                        type=str,
                        help="output dir")
    parser.add_argument('-task',
                        default='CoLA',
                        type=str,
                        help="Name of the task")
    parser.add_argument('-keep_heads',
                        type=int,
                        default=2,
                        help="the number of attention heads to keep")
    parser.add_argument('-ffn_hidden_dim',
                        type=int,
                        default=512,
                        help="Hidden size of the FFN subnetworks.")
    parser.add_argument('-num_layers',
                        type=int,
                        default=8,
                        help="the number of layers of the pruned model")
    parser.add_argument('-emb_hidden_dim',
                        type=int,
                        default=128,
                        help="Hidden size of embedding factorization. \
                    Do not factorize embedding if value==-1")
    args = parser.parse_args()

    torch.manual_seed(0)

    args.model_path = os.path.join(args.model_path, args.task)
    args.output_dir = os.path.join(args.output_dir, args.task)

    print('Loading BERT from %s...' % args.model_path)
    model = PrunTinyBertForSequenceClassification.from_pretrained(
        args.model_path, num_labels=num_labels[args.task.lower()])
    config = model.config
    tokenizer = BertTokenizer.from_pretrained(args.model_path,
                                              do_lower_case=True)
    model.bert.encoder.layer = torch.nn.ModuleList(
        [model.bert.encoder.layer[i] for i in range(args.num_layers)])

    if args.ffn_hidden_dim > config.prun_intermediate_size or \
            (args.emb_hidden_dim > config.emb_hidden_dim and
             config.emb_hidden_dim != -1):
        raise ValueError('Cannot prune the model to a larger size!')

    args.prun_ratio = args.ffn_hidden_dim / config.prun_intermediate_size
    print(
        'Pruning to %d heads, %d layers, %d FFN hidden dim, %d emb hidden dim...'
        % (args.keep_heads, args.num_layers, args.ffn_hidden_dim,
           args.emb_hidden_dim))
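    # The Taylor importance scores used below are assumed to have been
    # pre-computed in an earlier step and serialized to taylor.pkl.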
    importance_dir = os.path.join(args.model_path, 'taylor_score',
                                  'taylor.pkl')
    new_config = BertConfigPrun(num_attention_heads=args.keep_heads,
                                prun_hidden_size=int(args.keep_heads * 64),
                                prun_intermediate_size=args.ffn_hidden_dim,
                                num_hidden_layers=args.num_layers,
                                emb_hidden_dim=args.emb_hidden_dim)
    model = Taylor_pruning_structured(model, args.prun_ratio,
                                      config.num_attention_heads,
                                      args.keep_heads, importance_dir,
                                      args.emb_hidden_dim, new_config)

    output_dir = os.path.join(
        args.output_dir,
        'a%d_l%d_f%d_e%d' % (args.keep_heads, args.num_layers,
                             args.ffn_hidden_dim, args.emb_hidden_dim))

    print('Saving model to %s' % output_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    torch.save(model.state_dict(), os.path.join(output_dir,
                                                'pytorch_model.bin'))
    new_config.save_pretrained(output_dir)
    tokenizer.save_vocabulary(output_dir)
    model = PrunTinyBertForSequenceClassification.from_pretrained(
        output_dir, num_labels=num_labels[args.task.lower()])
    torch.save(model.state_dict(), os.path.join(output_dir,
                                                'pytorch_model.bin'))
    print(
        "Number of parameters: %d" %
        sum([model.state_dict()[key].nelement()
             for key in model.state_dict()]))
    print(model.state_dict().keys())
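Example #1 indexes a module-level num_labels mapping that is defined elsewhere in the repository. A minimal sketch of what that lookup plausibly contains for the GLUE tasks referenced here (the exact contents are an assumption):

# Hypothetical lookup assumed by Example #1; label counts follow the usual
# GLUE conventions (STS-B is a regression task with a single output).
num_labels = {
    "cola": 2,
    "sst-2": 2,
    "mrpc": 2,
    "qqp": 2,
    "qnli": 2,
    "rte": 2,
    "mnli": 3,
    "sts-b": 1,
}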
Example #2
                        type=int,
                        default=[1, 12])
    parser.add_argument('--intermediate_size_space',
                        nargs='+',
                        type=int,
                        default=[128, 3072])
    parser.add_argument('--mlm', action='store_true')

    parser.add_argument('--infer_cnt', type=int, default=10)

    args = parser.parse_args()

    config = BertConfig.from_pretrained(
        os.path.join(args.bert_model, 'config.json'))
    model = SuperTinyBertForPreTraining.from_scratch(args.bert_model, config)
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=True)

    device = 'cpu'
    model.to(device)
    model.eval()

    torch.set_num_threads(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # build arch space
    min_hidden_size, max_hidden_size = args.hidden_size_space
    min_ffn_size, max_ffn_size = args.intermediate_size_space
    min_qkv_size, max_qkv_size = args.qkv_size_space
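The fragment above only unpacks the architecture search bounds. A hedged sketch of how a candidate sub-architecture could be drawn from such a space; the function name and step size are illustrative, not taken from the original script:

import random

def sample_sub_architecture(min_hidden, max_hidden, min_ffn, max_ffn,
                            min_qkv, max_qkv, step=64):
    # Illustrative only: draw one random sub-architecture from the bounds.
    return {
        "hidden_size": random.randrange(min_hidden, max_hidden + 1, step),
        "intermediate_size": random.randrange(min_ffn, max_ffn + 1, step),
        "qkv_size": random.randrange(min_qkv, max_qkv + 1, step),
    }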
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--pretrain_model_name_or_path",
                        default=None,
                        type=str,
                        help="The pretrain model name or path.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--domain",
                        default='all',
                        type=str,
                        required=True,
                        help="The domain of given model.")
    parser.add_argument("--use_domain_loss",
                        default=False,
                        type=bool,
                        help="Whether to use domain loss.")
    parser.add_argument("--data_portion",
                        default=1.0,
                        type=float,
                        required=False,
                        help="How many data selected.")
    parser.add_argument("--domain_loss_weight",
                        default=0.2,
                        type=float,
                        help="The loss weight of domain.")
    parser.add_argument("--use_sample_weights",
                        default=False,
                        type=bool,
                        help="The loss weight of domain.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )

    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=int, default=50)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)

    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    processors = {
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "senti": SentiProcessor
    }

    output_modes = {"mnli": "classification", "senti": "classification"}

    if args.task_name.lower() == "mnli":
        domain_idx_mapping = {
            domain: idx
            for idx, domain in enumerate(
                "telephone,government,slate,fiction,travel".split(","))
        }
    else:
        domain_idx_mapping = {
            domain: idx
            for idx, domain in enumerate("books,dvd,electronics,kitchen".split(
                ","))
        }
    num_domains = len(domain_idx_mapping)

    # intermediate distillation default parameters
    default_params = {
        "mnli": {
            "num_train_epochs": 5,
            "max_seq_length": 128
        },
        "senti": {
            "num_train_epochs": 5,
            "max_seq_length": 128
        },
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte", "senti"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
    #     raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]

    if not args.do_eval:
        if task_name in default_params:
            args.num_train_epochs = default_params[task_name][
                "num_train_epochs"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name](portion=args.data_portion)
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.pretrain_model_name_or_path,
                                              do_lower_case=args.do_lower_case)

    if not args.do_eval:
        if not args.aug_train:
            train_examples = processor.get_train_examples(
                args.data_dir, args.domain)
        else:
            train_examples = processor.get_aug_examples(
                args.data_dir, args.domain)
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))

        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs

        portion_str = "_{}".format(
            args.data_portion) if args.data_portion != 1.0 else ""
        meta_str = "meta" if args.use_domain_loss or args.use_sample_weights else ""
        cached_train_path = os.path.join(
            args.data_dir, "cached_train_features_{}{}{}{}.pt".format(
                args.domain, meta_str,
                "_with_weights" if args.use_sample_weights else "",
                portion_str))
        if os.path.exists(cached_train_path):
            train_features = torch.load(cached_train_path)
        else:
            train_features = convert_examples_to_features(
                train_examples, label_list, args.max_seq_length, tokenizer,
                output_mode, domain_idx_mapping)
            torch.save(train_features, cached_train_path)
            print("Save to cached path %s" % cached_train_path)
        train_data, _ = get_tensor_data(output_mode, train_features)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

    if args.do_eval:
        eval_examples = processor.get_test_examples(args.data_dir, args.domain)
    else:
        eval_examples = processor.get_dev_examples(args.data_dir, args.domain)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer, output_mode,
                                                 domain_idx_mapping)
    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    meta_teacher_model = MetaTeacherForSequenceClassification.from_pretrained(
        args.pretrain_model_name_or_path,
        num_labels=num_labels,
        num_domains=num_domains)
    meta_teacher_model.to(device)

    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        meta_teacher_model.eval()
        result = do_eval(meta_teacher_model, task_name, eval_dataloader,
                         device, output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
    else:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        if n_gpu > 1:
            meta_teacher_model = torch.nn.DataParallel(meta_teacher_model)
        # Prepare optimizer
        param_optimizer = list(meta_teacher_model.named_parameters())
        size = 0
        for n, p in meta_teacher_model.named_parameters():
            logger.info('n: {}'.format(n))
            size += p.nelement()

        logger.info('Total parameters: {}'.format(size))
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        schedule = 'warmup_linear'
        optimizer = BertAdam(optimizer_grouped_parameters,
                             schedule=schedule,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

        # Train and evaluate
        global_step = 0
        best_dev_acc = 0.0
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        ce_loss_fn = CrossEntropyLoss(reduction="none")

        for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0.
            tr_cls_loss = 0.

            meta_teacher_model.train()
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", ascii=True)):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, label_ids, seq_lengths, domain_ids, sample_weights = batch
                if input_ids.size()[0] != args.train_batch_size:
                    continue

                logits, domain_logits, *_ = meta_teacher_model(
                    input_ids, segment_ids, input_mask, domain_ids)
                losses = ce_loss_fn(logits, label_ids)

                if args.use_domain_loss:
                    shuffled_domain_ids = domain_ids[torch.randperm(
                        domain_ids.shape[0])]
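                    # The domain labels are shuffled before the domain loss is
                    # computed, presumably so this term acts as a regularizing
                    # domain-confusion signal rather than a real domain classifier.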
                    domain_losses = ce_loss_fn(domain_logits,
                                               shuffled_domain_ids)
                    losses += args.domain_loss_weight * domain_losses
                if args.use_sample_weights:
                    loss = torch.mean(losses * sample_weights)
                else:
                    loss = torch.mean(losses)

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += label_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if (global_step + 1) % args.eval_step == 0:
                    logger.info("***** Running evaluation *****")
                    logger.info("  Epoch = {} iter {} step".format(
                        epoch_, global_step))
                    logger.info("  Num examples = %d", len(eval_examples))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    meta_teacher_model.eval()

                    loss = tr_loss / (step + 1)
                    cls_loss = tr_cls_loss / (step + 1)

                    result = do_eval(meta_teacher_model, task_name,
                                     eval_dataloader, device, output_mode,
                                     eval_labels, num_labels)
                    result['global_step'] = global_step
                    result['cls_loss'] = cls_loss
                    result['loss'] = loss

                    result_to_file(result, output_eval_file)

                    save_model = False
                    if task_name in acc_tasks and result['acc'] > best_dev_acc:
                        best_dev_acc = result['acc']
                        save_model = True

                    if task_name in corr_tasks and result[
                            'corr'] > best_dev_acc:
                        best_dev_acc = result['corr']
                        save_model = True

                    if task_name in mcc_tasks and result['mcc'] > best_dev_acc:
                        best_dev_acc = result['mcc']
                        save_model = True

                    if save_model:
                        logger.info("***** Save model *****")

                        model_to_save = meta_teacher_model.module if hasattr(meta_teacher_model, 'module') \
                            else meta_teacher_model

                        model_name = WEIGHTS_NAME
                        output_model_file = os.path.join(
                            args.output_dir, model_name)
                        output_config_file = os.path.join(
                            args.output_dir, CONFIG_NAME)

                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)

                        if oncloud:
                            logging.info(
                                mox.file.list_directory(args.output_dir,
                                                        recursive=True))
                            logging.info(
                                mox.file.list_directory('.', recursive=True))
                            mox.file.copy_parallel(args.output_dir,
                                                   args.data_url)
                            mox.file.copy_parallel('.', args.data_url)

                    meta_teacher_model.train()
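Examples #3 and #4 call result_to_file without showing it. A minimal sketch of what such a helper might look like, under the assumption that it simply appends metrics to a text log (this is not the repository's implementation):

def result_to_file(result, file_name):
    # Append the evaluation metrics, one "key = value" pair per line.
    with open(file_name, "a") as writer:
        writer.write("***** Eval results *****\n")
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))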
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files or the task.")
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        required=True,
                        help="The student model dir.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where model checkpoints will be written.")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_len",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length ")
    parser.add_argument("--num_labels",
                        default=2,
                        type=int,
                        required=True,
                        help="")
    parser.add_argument("--task_mode",
                        default='classification',
                        type=str,
                        required=False,
                        help="task type")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run train on the train set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate")

    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=int, default=50)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)

    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    num_labels = args.num_labels

    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=args.do_lower_case)

    if args.do_train:
        train_path = os.path.join(args.data_dir, 'train.txt')
        eval_path = os.path.join(args.data_dir, 'eval.txt')
        train_examples = read_examples(train_path)
        eval_examples = read_examples(eval_path)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer,
                                                      args.max_seq_len)
        eval_features = convert_examples_to_features(eval_examples, tokenizer,
                                                     args.max_seq_len)

        train_features = MyDataLoader(train_features)
        eval_features = MyDataLoader(eval_features)

        train_dataloader = DataLoader(train_features,
                                      shuffle=True,
                                      batch_size=args.train_batch_size)
        # eval_dataloader = DataLoader(eval_features, shuffle=False, batch_size=args.eval_batch_size)

        teacher_model = TinyBertForSequenceClassification.from_pretrained(
            args.teacher_model, num_labels=num_labels)
        teacher_model.to(device)

    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.student_model, num_labels=num_labels)
    student_model.to(device)
    # If --do_train is not set, only prediction is performed
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        if n_gpu > 1:
            student_model = torch.nn.DataParallel(student_model)
            teacher_model = torch.nn.DataParallel(teacher_model)
        # Prepare optimizer
        param_optimizer = list(student_model.named_parameters())
        size = 0
        for n, p in student_model.named_parameters():
            logger.info('n: {}'.format(n))
            size += p.nelement()

        logger.info('Total parameters: {}'.format(size))
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        schedule = 'warmup_linear'
        if not args.pred_distill:
            schedule = 'none'
        optimizer = BertAdam(optimizer_grouped_parameters,
                             schedule=schedule,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)
        # Prepare loss functions
        loss_mse = MSELoss()

        def soft_cross_entropy(predicts, targets):
            student_likelihood = torch.nn.functional.log_softmax(predicts,
                                                                 dim=-1)
            targets_prob = torch.nn.functional.softmax(targets, dim=-1)
            return (-targets_prob * student_likelihood).mean()

        # Train and evaluate
        global_step = 0
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

        for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0.
            tr_att_loss = 0.
            tr_rep_loss = 0.
            tr_cls_loss = 0.

            student_model.train()
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", ascii=True)):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
                if input_ids.size()[0] != args.train_batch_size:
                    continue

                att_loss = 0.
                rep_loss = 0.
                cls_loss = 0.

                student_logits, student_atts, student_reps = student_model(
                    input_ids, segment_ids, input_mask, is_student=True)

                with torch.no_grad():
                    teacher_logits, teacher_atts, teacher_reps = teacher_model(
                        input_ids, segment_ids, input_mask)

                # Stage 1: intermediate-layer distillation (attention and hidden states)
                if not args.pred_distill:
                    teacher_layer_num = len(teacher_atts)
                    student_layer_num = len(student_atts)
                    assert teacher_layer_num % student_layer_num == 0
                    layers_per_block = int(teacher_layer_num /
                                           student_layer_num)
                    new_teacher_atts = [
                        teacher_atts[i * layers_per_block + layers_per_block -
                                     1] for i in range(student_layer_num)
                    ]

                    for student_att, teacher_att in zip(
                            student_atts, new_teacher_atts):
                        student_att = torch.where(
                            student_att <= -1e2,
                            torch.zeros_like(student_att).to(device),
                            student_att)
                        teacher_att = torch.where(
                            teacher_att <= -1e2,
                            torch.zeros_like(teacher_att).to(device),
                            teacher_att)

                        tmp_loss = loss_mse(student_att, teacher_att)
                        att_loss += tmp_loss

                    new_teacher_reps = [
                        teacher_reps[i * layers_per_block]
                        for i in range(student_layer_num + 1)
                    ]
                    new_student_reps = student_reps
                    for student_rep, teacher_rep in zip(
                            new_student_reps, new_teacher_reps):
                        tmp_loss = loss_mse(student_rep, teacher_rep)
                        rep_loss += tmp_loss

                    loss = rep_loss + att_loss
                    tr_att_loss += att_loss.item()
                    tr_rep_loss += rep_loss.item()
                # Stage 2: prediction-layer distillation
                else:
                    if args.task_mode == "classification":
                        cls_loss = soft_cross_entropy(
                            student_logits / args.temperature,
                            teacher_logits / args.temperature)
                    elif args.task_mode == "regression":
                        loss_mse = MSELoss()
                        cls_loss = loss_mse(student_logits.view(-1),
                                            label_ids.view(-1))
                    loss = cls_loss
                    tr_cls_loss += cls_loss.item()

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += label_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if (global_step + 1) % args.eval_step == 0:
                    logger.info("***** Running evaluation *****")
                    logger.info("  Epoch = {} iter {} step".format(
                        epoch_, global_step))
                    logger.info("  Num examples = %d", len(eval_examples))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    student_model.eval()

                    loss = tr_loss / (step + 1)
                    cls_loss = tr_cls_loss / (step + 1)
                    att_loss = tr_att_loss / (step + 1)
                    rep_loss = tr_rep_loss / (step + 1)

                    result = {}
                    result['global_step'] = global_step
                    result['cls_loss'] = cls_loss
                    result['att_loss'] = att_loss
                    result['rep_loss'] = rep_loss
                    result['loss'] = loss

                    result_to_file(result, output_eval_file)
                    student_model.train()

            logger.info("***** Save model *****")
            model_to_save = student_model.module if hasattr(
                student_model, 'module') else student_model
            model_name = f'{epoch_}_{WEIGHTS_NAME}'
            output_model_file = os.path.join(args.output_dir, model_name)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            tokenizer.save_vocabulary(args.output_dir)
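The intermediate-distillation stage above, and the pre-training distillation script that follows, pick teacher layers with the same index arithmetic. A standalone sketch of that mapping (the helper name is illustrative):

def map_teacher_layers(teacher_layer_num, student_layer_num):
    # For a 12-layer teacher and a 4-layer student this yields attention
    # indices [2, 5, 8, 11] and hidden-state indices [0, 3, 6, 9, 12];
    # hidden states include the embedding output, hence the extra entry.
    assert teacher_layer_num % student_layer_num == 0
    layers_per_block = teacher_layer_num // student_layer_num
    att_idx = [i * layers_per_block + layers_per_block - 1
               for i in range(student_layer_num)]
    rep_idx = [i * layers_per_block for i in range(student_layer_num + 1)]
    return att_idx, rep_idx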
Example #5
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--pregenerated_data",
                        type=Path,
                        required=True)
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True)

    # Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")

    parser.add_argument("--reduce_memory",
                        action="store_true",
                        help="Store training data as on-disc memmaps to massively reduce memory usage")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-4,
                        type=float, metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--continue_train',
                        action='store_true',
                        help='Whether to train from checkpoints')

    # Additional arguments
    parser.add_argument('--eval_step',
                        type=int,
                        default=1000)

    # This is used for running on Huawei Cloud.
    parser.add_argument('--data_url',
                        type=str,
                        default="")

    args = parser.parse_args()
    logger.info('args:{}'.format(args))

    samples_per_epoch = []
    for i in range(int(args.num_train_epochs)):
        epoch_file = args.pregenerated_data / "epoch_{}.json".format(i)
        metrics_file = args.pregenerated_data / "epoch_{}_metrics.json".format(i)
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit("No training data was found!")
            print("Warning! There are fewer epochs of pregenerated data ({}) than training epochs ({}).".format(i, args.num_train_epochs))
            print("This script will loop over the available data, but training diversity may be negatively impacted.")
            num_data_epochs = i
            break
    else:
        num_data_epochs = int(args.num_train_epochs)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.teacher_model, do_lower_case=args.do_lower_case)

    total_train_examples = 0
    for i in range(int(args.num_train_epochs)):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = int(
        total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    if args.continue_train:
        student_model = TinyBertForPreTraining.from_pretrained(args.student_model)
    else:
        student_model = TinyBertForPreTraining.from_scratch(args.student_model)
    teacher_model = BertModel.from_pretrained(args.teacher_model)

    # student_model = TinyBertForPreTraining.from_scratch(args.student_model, fit_size=teacher_model.config.hidden_size)
    student_model.to(device)
    teacher_model.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        teacher_model = DDP(teacher_model)
    elif n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)
        teacher_model = torch.nn.DataParallel(teacher_model)

    size = 0
    for n, p in student_model.named_parameters():
        logger.info('n: {}'.format(n))
        logger.info('p: {}'.format(p.nelement()))
        size += p.nelement()

    logger.info('Total parameters: {}'.format(size))

    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    loss_mse = MSELoss()
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)

    global_step = 0
    logging.info("***** Running training *****")
    logging.info("  Num examples = {}".format(total_train_examples))
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)

    for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
        epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
                                            num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory)
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)
        train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

        tr_loss = 0.
        tr_att_loss = 0.
        tr_rep_loss = 0.
        student_model.train()
        nb_tr_examples, nb_tr_steps = 0, 0
        with tqdm(total=len(train_dataloader), desc="Epoch {}".format(epoch)) as pbar:
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", ascii=True)):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                if input_ids.size()[0] != args.train_batch_size:
                    continue

                att_loss = 0.
                rep_loss = 0.

                student_atts, student_reps = student_model(input_ids, segment_ids, input_mask)
                teacher_reps, teacher_atts, _ = teacher_model(input_ids, segment_ids, input_mask)
                teacher_reps = [teacher_rep.detach() for teacher_rep in teacher_reps]  # speedup 1.5x
                teacher_atts = [teacher_att.detach() for teacher_att in teacher_atts]

                teacher_layer_num = len(teacher_atts)
                student_layer_num = len(student_atts)
                assert teacher_layer_num % student_layer_num == 0
                layers_per_block = int(teacher_layer_num / student_layer_num)
                new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                                    for i in range(student_layer_num)]

                for student_att, teacher_att in zip(student_atts, new_teacher_atts):
                    student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(device),
                                              student_att)
                    teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device),
                                              teacher_att)
                    att_loss += loss_mse(student_att, teacher_att)

                new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)]
                new_student_reps = student_reps

                for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
                    rep_loss += loss_mse(student_rep, teacher_rep)

                loss = att_loss + rep_loss

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_att_loss += att_loss.item()
                tr_rep_loss += rep_loss.item()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                pbar.update(1)

                mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                mean_att_loss = tr_att_loss * args.gradient_accumulation_steps / nb_tr_steps
                mean_rep_loss = tr_rep_loss * args.gradient_accumulation_steps / nb_tr_steps

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                    if (global_step + 1) % args.eval_step == 0:
                        result = {}
                        result['global_step'] = global_step
                        result['loss'] = mean_loss
                        result['att_loss'] = mean_att_loss
                        result['rep_loss'] = mean_rep_loss
                        output_eval_file = os.path.join(args.output_dir, "log.txt")
                        with open(output_eval_file, "a") as writer:
                            logger.info("***** Eval results *****")
                            for key in sorted(result.keys()):
                                logger.info("  %s = %s", key, str(result[key]))
                                writer.write("%s = %s\n" % (key, str(result[key])))

                        # Save a trained model
                        model_name = "step_{}_{}".format(global_step, WEIGHTS_NAME)
                        logging.info("** ** * Saving fine-tuned model ** ** * ")
                        # Only save the model it-self
                        model_to_save = student_model.module if hasattr(student_model, 'module') else student_model

                        output_model_file = os.path.join(args.output_dir, model_name)
                        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

                        torch.save(model_to_save.state_dict(), output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)

                        if oncloud:
                            logging.info(mox.file.list_directory(args.output_dir, recursive=True))
                            logging.info(mox.file.list_directory('.', recursive=True))
                            mox.file.copy_parallel(args.output_dir, args.data_url)
                            mox.file.copy_parallel('.', args.data_url)

            model_name = "step_{}_{}".format(global_step, WEIGHTS_NAME)
            logging.info("** ** * Saving fine-tuned model ** ** * ")
            model_to_save = student_model.module if hasattr(student_model, 'module') else student_model

            output_model_file = os.path.join(args.output_dir, model_name)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            tokenizer.save_vocabulary(args.output_dir)

            if oncloud:
                logging.info(mox.file.list_directory(args.output_dir, recursive=True))
                logging.info(mox.file.list_directory('.', recursive=True))
                mox.file.copy_parallel(args.output_dir, args.data_url)
                mox.file.copy_parallel('.', args.data_url)
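The oncloud flag and the mox module used for the Huawei Cloud uploads are defined outside these snippets. A plausible guard, given as an assumption rather than the repository's actual code:

try:
    import moxing as mox  # available only on Huawei ModelArts
    oncloud = True
except ImportError:
    oncloud = False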
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        required=True,
                        help="The student model dir.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )

    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=int, default=50)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)
    parser.add_argument('--local_rank', type=int, default=-1)

    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification"
    }

    # intermediate distillation default parameters
    default_params = {
        "cola": {
            "num_train_epochs": 50,
            "max_seq_length": 64
        },
        "mnli": {
            "num_train_epochs": 5,
            "max_seq_length": 128
        },
        "mrpc": {
            "num_train_epochs": 20,
            "max_seq_length": 128
        },
        "sst-2": {
            "num_train_epochs": 10,
            "max_seq_length": 64
        },
        "sts-b": {
            "num_train_epochs": 20,
            "max_seq_length": 128
        },
        "qqp": {
            "num_train_epochs": 5,
            "max_seq_length": 128
        },
        "qnli": {
            "num_train_epochs": 10,
            "max_seq_length": 128
        },
        "rte": {
            "num_train_epochs": 20,
            "max_seq_length": 128
        }
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices

    n_gpu = torch.cuda.device_count()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]

    if not args.pred_distill and not args.do_eval:
        if task_name in default_params:
            args.num_train_epochs = default_params[task_name][
                "num_train_epochs"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=args.do_lower_case)

    if not args.do_eval:
        #if not args.aug_train:
        #    train_examples = processor.get_train_examples(args.data_dir)
        #else:
        #    train_examples = processor.get_aug_examples(args.data_dir)
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))

        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

        # rewrite data processing here
        assert args.task_name == "MNLI", "this script currently supports MNLI only"
        mnli_datasets = load_dataset("text",
                                     data_files=os.path.join(
                                         args.data_dir, "train_aug.tsv"))
        label_classes = processor.get_labels()
        label_map = {label: i for i, label in enumerate(label_classes)}

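        # Build BERT inputs from the tab-separated MNLI rows: tokenize the
        # premise/hypothesis pair (columns 8 and 9), form
        # [CLS] A [SEP] B [SEP], then pad everything to max_seq_length.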
        def preprocess_func(examples, max_seq_length=args.max_seq_length):
            splits = [e.split('\t') for e in examples['text']]
            # tokenize for sent1 & sent2
            tokens_s1 = [tokenizer.tokenize(e[8]) for e in splits]
            tokens_s2 = [tokenizer.tokenize(e[9]) for e in splits]
            for t1, t2 in zip(tokens_s1, tokens_s2):
                truncate_seq_pair(t1, t2, max_length=max_seq_length - 3)
            input_ids_list = []
            input_mask_list = []
            segment_ids_list = []
            seq_length_list = []
            labels_list = []
            labels = [e[-1] for e in splits]
            # print(len(labels))
            for token_a, token_b, l in zip(
                    tokens_s1, tokens_s2,
                    labels):  # zip(tokens_as, tokens_bs):
                tokens = ["[CLS]"] + token_a + ["[SEP]"]
                segment_ids = [0] * len(tokens)
                tokens += token_b + ["[SEP]"]
                segment_ids += [1] * (len(token_b) + 1)
                input_ids = tokenizer.convert_tokens_to_ids(
                    tokens)  # tokenize to id
                input_mask = [1] * len(input_ids)
                seq_length = len(input_ids)
                padding = [0] * (max_seq_length - len(input_ids))
                input_ids += padding
                input_mask += padding
                segment_ids += padding
                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length
                input_ids_list.append(input_ids)
                input_mask_list.append(input_mask)
                segment_ids_list.append(segment_ids)
                seq_length_list.append(seq_length)
                labels_list.append(label_map[l])

            results = {
                "input_ids": input_ids_list,
                "input_mask": input_mask_list,
                "segment_ids": segment_ids_list,
                "seq_length": seq_length_list,
                "label_ids": labels_list
            }

            return results

        mnli_datasets = mnli_datasets.map(preprocess_func, batched=True)

        # train_features = convert_examples_to_features(train_examples, label_list,
        #                                               args.max_seq_length, tokenizer, output_mode, logger)
        train_data = mnli_datasets['train'].remove_columns('text')

        print(train_data[0])
        # train_data, _ = get_tensor_data(output_mode, train_features)
        num_train_optimization_steps = int(
            len(train_data) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
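        # Distributed data parallel setup: one process per GPU, NCCL backend;
        # DistributedSampler shards the training data across ranks.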
        logger.info("Initializing Distributed Environment")
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend="nccl")
        train_sampler = torch.utils.data.DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

    eval_examples = processor.get_dev_examples(args.data_dir)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer, output_mode,
                                                 logger)
    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # DDP setting
    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.student_model, num_labels=num_labels).to(device)

    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        student_model.eval()
        result = do_eval(student_model, task_name, eval_dataloader, device,
                         output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
    else:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        teacher_model = TinyBertForSequenceClassification.from_pretrained(
            args.teacher_model, num_labels=num_labels).to(device)
        student_model = DDP(student_model,
                            device_ids=[local_rank],
                            output_device=local_rank)
        teacher_model = DDP(teacher_model,
                            device_ids=[local_rank],
                            output_device=local_rank)
        # Prepare optimizer
        param_optimizer = list(student_model.named_parameters())
        size = 0
        for n, p in student_model.named_parameters():
            logger.info('n: {}'.format(n))
            size += p.nelement()

        logger.info('Total parameters: {}'.format(size))
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        schedule = 'warmup_linear'
        if not args.pred_distill:
            schedule = 'none'
        optimizer = BertAdam(optimizer_grouped_parameters,
                             schedule=schedule,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)
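        # Mixed-precision training: forward passes below run under autocast()
        # and GradScaler rescales the loss to avoid fp16 gradient underflow.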
        scaler = torch.cuda.amp.GradScaler()

        # Prepare loss functions
        loss_mse = MSELoss()

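        # Distillation loss on logits: cross-entropy between the teacher's
        # softened distribution and the student's log-probabilities.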
        def soft_cross_entropy(predicts, targets):
            student_likelihood = torch.nn.functional.log_softmax(predicts,
                                                                 dim=-1)
            targets_prob = torch.nn.functional.softmax(targets, dim=-1)
            return (-targets_prob * student_likelihood).mean()

        # Train and evaluate
        global_step = 0
        best_dev_acc = 0.0
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

        for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0.
            tr_att_loss = 0.
            tr_rep_loss = 0.
            tr_cls_loss = 0.

            student_model.train()
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", ascii=True)):
                # optimizer.zero_grad()
                #batch = tuple(torch.tensor(t, dtype=torch.long).to(device) for t in batch)
                # print(batch)
                inputs = {}
                for k, v in batch.items():
                    if isinstance(v, torch.Tensor):
                        inputs[k] = v.to(device)
                    elif isinstance(v, List):
                        inputs[k] = torch.stack(v, dim=1).to(device)

                # inputs = {k: torch.tensor(v, dtype=torch.long).to(device) for k, v in batch.items()}
                # input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
                # print([(k, inputs[k].size()) for k in inputs])
                if inputs['input_ids'].size()[0] != args.train_batch_size:
                    continue

                att_loss = 0.
                rep_loss = 0.
                cls_loss = 0.
                with autocast():
                    student_logits, student_atts, student_reps = student_model(
                        inputs['input_ids'],
                        inputs['segment_ids'],
                        inputs['input_mask'],
                        is_student=True)
                    with torch.no_grad():
                        teacher_logits, teacher_atts, teacher_reps = teacher_model(
                            inputs['input_ids'], inputs['segment_ids'],
                            inputs['input_mask'])

                    if not args.pred_distill:
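                        # Intermediate-layer distillation: map each student
                        # layer to every `layers_per_block`-th teacher layer
                        # and match attention maps and hidden states with MSE.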
                        teacher_layer_num = len(teacher_atts)
                        student_layer_num = len(student_atts)
                        assert teacher_layer_num % student_layer_num == 0
                        layers_per_block = int(teacher_layer_num /
                                               student_layer_num)
                        new_teacher_atts = [
                            teacher_atts[i * layers_per_block +
                                         layers_per_block - 1]
                            for i in range(student_layer_num)
                        ]

                        for student_att, teacher_att in zip(
                                student_atts, new_teacher_atts):
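                            # Zero out the large negative values used to mask
                            # padded positions so they do not dominate the MSE.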
                            student_att = torch.where(
                                student_att <= -1e2,
                                torch.zeros_like(student_att).to(device),
                                student_att)
                            teacher_att = torch.where(
                                teacher_att <= -1e2,
                                torch.zeros_like(teacher_att).to(device),
                                teacher_att)

                            tmp_loss = loss_mse(student_att, teacher_att)
                            att_loss += tmp_loss

                        new_teacher_reps = [
                            teacher_reps[i * layers_per_block]
                            for i in range(student_layer_num + 1)
                        ]
                        new_student_reps = student_reps
                        for student_rep, teacher_rep in zip(
                                new_student_reps, new_teacher_reps):
                            tmp_loss = loss_mse(student_rep, teacher_rep)
                            rep_loss += tmp_loss
                        # zero-weighted prediction term: keeps the logits in
                        # the autograd graph for AMP without changing the loss
                        loss = rep_loss + att_loss + 0 * soft_cross_entropy(
                            student_logits / args.temperature,
                            teacher_logits / args.temperature)
                        tr_att_loss += att_loss.item()
                        tr_rep_loss += rep_loss.item()
                    else:
                        if output_mode == "classification":
                            cls_loss = soft_cross_entropy(
                                student_logits / args.temperature,
                                teacher_logits / args.temperature)
                        elif output_mode == "regression":
                            cls_loss = loss_mse(student_logits.view(-1),
                                                inputs['label_ids'].view(-1))

                        loss = cls_loss + 0 * loss_mse(
                            student_atts[0], teacher_atts[0]) + 0 * loss_mse(
                                teacher_reps[0], student_reps[0])
                        tr_cls_loss += cls_loss.item()

                # if n_gpu > 1:
                #     loss = loss.mean()  # mean() to average on multi-gpu.
                # if args.gradient_accumulation_steps > 1:
                #     loss = loss / args.gradient_accumulation_steps
                scaler.scale(loss).backward()
                # loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += inputs['label_ids'].size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # optimizer.step()
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1

                if (global_step +
                        1) % args.eval_step == 0 and args.local_rank == 0:
                    logger.info("***** Running evaluation *****")
                    logger.info("  Epoch = {} iter {} step".format(
                        epoch_, global_step))
                    logger.info("  Num examples = %d", len(eval_examples))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    student_model.eval()

                    loss = tr_loss / (step + 1)
                    cls_loss = tr_cls_loss / (step + 1)
                    att_loss = tr_att_loss / (step + 1)
                    rep_loss = tr_rep_loss / (step + 1)

                    result = {}
                    if args.pred_distill:
                        result = do_eval(student_model, task_name,
                                         eval_dataloader, device, output_mode,
                                         eval_labels, num_labels)
                    result['global_step'] = global_step
                    result['cls_loss'] = cls_loss
                    result['att_loss'] = att_loss
                    result['rep_loss'] = rep_loss
                    result['loss'] = loss

                    result_to_file(result, output_eval_file)

                    if not args.pred_distill:
                        save_model = True
                    else:
                        save_model = False

                        if task_name in acc_tasks and result[
                                'acc'] > best_dev_acc:
                            best_dev_acc = result['acc']
                            save_model = True

                        if task_name in corr_tasks and result[
                                'corr'] > best_dev_acc:
                            best_dev_acc = result['corr']
                            save_model = True

                        if task_name in mcc_tasks and result[
                                'mcc'] > best_dev_acc:
                            best_dev_acc = result['mcc']
                            save_model = True

                    if save_model and args.local_rank == 0:
                        logger.info("***** Save model *****")

                        model_to_save = student_model.module if hasattr(
                            student_model, 'module') else student_model

                        model_name = WEIGHTS_NAME
                        # if not args.pred_distill:
                        #     model_name = "step_{}_{}".format(global_step, WEIGHTS_NAME)
                        output_model_file = os.path.join(
                            args.output_dir, model_name)
                        output_config_file = os.path.join(
                            args.output_dir, CONFIG_NAME)

                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)

                        # Test mnli-mm
                        if args.pred_distill and task_name == "mnli":
                            task_name = "mnli-mm"
                            processor = processors[task_name]()
                            if not os.path.exists(args.output_dir + '-MM'):
                                os.makedirs(args.output_dir + '-MM')

                            eval_examples = processor.get_dev_examples(
                                args.data_dir)

                            eval_features = convert_examples_to_features(
                                eval_examples, label_list, args.max_seq_length,
                                tokenizer, output_mode, logger)
                            eval_data, eval_labels = get_tensor_data(
                                output_mode, eval_features)

                            logger.info("***** Running mm evaluation *****")
                            logger.info("  Num examples = %d",
                                        len(eval_examples))
                            logger.info("  Batch size = %d",
                                        args.eval_batch_size)

                            eval_sampler = SequentialSampler(eval_data)
                            eval_dataloader = DataLoader(
                                eval_data,
                                sampler=eval_sampler,
                                batch_size=args.eval_batch_size)

                            result = do_eval(student_model, task_name,
                                             eval_dataloader, device,
                                             output_mode, eval_labels,
                                             num_labels)

                            result['global_step'] = global_step

                            tmp_output_eval_file = os.path.join(
                                args.output_dir + '-MM', "eval_results.txt")
                            result_to_file(result, tmp_output_eval_file)

                            task_name = 'mnli'

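                        # On the cloud platform, mirror the output directory
                        # and working directory back to data_url via MoXing.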
                        if oncloud:
                            logging.info(
                                mox.file.list_directory(args.output_dir,
                                                        recursive=True))
                            logging.info(
                                mox.file.list_directory('.', recursive=True))
                            mox.file.copy_parallel(args.output_dir,
                                                   args.data_url)
                            mox.file.copy_parallel('.', args.data_url)

                    student_model.train()
def main():
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--bert_model", type=str, required=True)
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        "--do_whole_word_mask",
        action="store_true",
        help=
        "Whether to use whole word masking rather than per-WordPiece masking.")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Reduce memory usage for large datasets by keeping data on disc rather than in memory"
    )

    parser.add_argument("--num_workers",
                        type=int,
                        default=1,
                        help="The number of workers to use to write the files")
    parser.add_argument("--epochs_to_generate",
                        type=int,
                        default=3,
                        help="Number of epochs of data to pregenerate")
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument(
        "--short_seq_prob",
        type=float,
        default=0.1,
        help="Probability of making a short sentence as a training example")
    parser.add_argument(
        "--masked_lm_prob",
        type=float,
        default=0.0,
        help="Probability of masking each token for the LM task"
    )  # no [mask] symbol in corpus
    parser.add_argument(
        "--max_predictions_per_seq",
        type=int,
        default=20,
        help="Maximum number of tokens to mask in each sequence")
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--one_seq', action='store_true')

    args = parser.parse_args()

    if args.num_workers > 1 and args.reduce_memory:
        raise ValueError("Cannot use multiple workers while reducing memory")

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    vocab_list = list(tokenizer.vocab.keys())
    doc_num = 0
    with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
        import os
        for root, dirs, files in os.walk(args.train_corpus, topdown=False):
            for name in files:
                logger.info(f'Start on {Path(root, name)}')
                with Path(root, name).open() as f:
                    doc = []
                    for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
                        line = line.strip()
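                        # A blank line marks a document boundary.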
                        if line == "":
                            docs.add_document(doc)
                            doc = []
                            doc_num += 1
                            if doc_num % 100 == 0:
                                logger.info('loaded {} docs!'.format(doc_num))
                        else:
                            tokens = tokenizer.tokenize(line)
                            doc.append(tokens)
                    if doc:
                        docs.add_document(
                            doc
                        )  # If the last doc didn't end on a newline, make sure it still gets added
        if len(docs) <= 1:
            exit(
                "ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
                "ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
                "indicate breaks between documents in your input file. If your dataset does not contain multiple "
                "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
                "sections or paragraphs.")

        args.output_dir.mkdir(exist_ok=True)

        if args.num_workers > 1:
            writer_workers = Pool(
                min(args.num_workers, args.epochs_to_generate))
            arguments = [(docs, vocab_list, args, idx)
                         for idx in range(args.epochs_to_generate)]
            writer_workers.starmap(create_training_file, arguments)
        else:
            for epoch in trange(args.epochs_to_generate, desc="Epoch"):
                bi_text = not args.one_seq
                epoch_file, metric_file = create_training_file(docs,
                                                               vocab_list,
                                                               args,
                                                               epoch,
                                                               bi_text=bi_text)

                if oncloud:
                    logging.info(
                        mox.file.list_directory(str(args.output_dir),
                                                recursive=True))
                    logging.info(mox.file.list_directory('.', recursive=True))
                    mox.file.copy_parallel(str(args.output_dir), args.data_url)
                    mox.file.copy_parallel('.', args.data_url)

                    os.remove(str(epoch_file))
                    os.remove(str(metric_file))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        required=True,
                        help="The student model dir.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay', '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")

    # added arguments
    parser.add_argument('--aug_train',
                        action='store_true')
    parser.add_argument('--eval_step',
                        type=int,
                        default=50)
    parser.add_argument('--pred_distill',
                        action='store_true')
    parser.add_argument('--data_url',
                        type=str,
                        default="")
    parser.add_argument('--temperature',
                        type=float,
                        default=1.)

    args = parser.parse_args()
    logger.info('The args: {}'.format(args))
    wandb.config.update(args)

    processors = {
        "race": RaceProcessor,
    }

    # intermediate distillation default parameters
    default_params = {
        "race": {"num_train_epochs": 3, "max_seq_length": 80},
    }

    # Prepare devices
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)

    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]

    if not args.pred_distill and not args.do_eval:
        if task_name in default_params:
            args.num_train_epochs = default_params[task_name]["num_train_epochs"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model, do_lower_case=args.do_lower_case)

    if not args.do_eval:
        if not args.aug_train:
            train_examples = processor.get_train_examples(args.data_dir)
        else:
            train_examples = processor.get_aug_examples(args.data_dir)
        if args.gradient_accumulation_steps < 1:
            raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))

        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs

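        # Cache tokenized training features on disk so repeated runs can skip
        # the expensive preprocessing step.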
        cached_features_file_train = os.path.join(
            args.data_dir,
            "cached_train_{}_{}_{}_tinybert".format(tokenizer.__class__.__name__, str(args.max_seq_length),
                                                    task_name, ),
        )

        if os.path.exists(cached_features_file_train):
            train_features = torch.load(cached_features_file_train)
        else:
            train_features = convert_examples_to_features(train_examples, label_list, args.max_seq_length, tokenizer)
            torch.save(train_features, cached_features_file_train)

        train_data, _ = get_tensor_data(train_features)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    eval_examples = processor.get_dev_examples(args.data_dir)

    cached_features_file_eval = os.path.join(
        args.data_dir,
        "cached_dev_{}_{}_{}_tinybert".format(tokenizer.__class__.__name__, str(args.max_seq_length), task_name, ),
    )

    if os.path.exists(cached_features_file_eval):
        eval_features = torch.load(cached_features_file_eval)
    else:
        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer)
        torch.save(eval_features, cached_features_file_eval)

    eval_data, eval_labels = get_tensor_data(eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    if not args.do_eval:
        teacher_model = TinyBertForMultipleChoice.from_pretrained(args.teacher_model)
        teacher_model.to(device)

    student_model = TinyBertForMultipleChoice.from_pretrained(args.student_model)
    student_model.to(device)
    wandb.watch(student_model, log='all')

    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        student_model.eval()
        result = do_eval(student_model, task_name, eval_dataloader,
                         device, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            wandb.log(result)
    else:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        # Prepare optimizer
        param_optimizer = list(student_model.named_parameters())
        size = 0
        for n, p in student_model.named_parameters():
            logger.info('n: {}'.format(n))
            size += p.nelement()

        logger.info('Total parameters: {}'.format(size))
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        schedule = 'warmup_linear'
        if not args.pred_distill:
            schedule = 'none'

        optimizer = BertAdam(optimizer_grouped_parameters,
                             schedule=schedule,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

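        # Optional mixed precision through NVIDIA apex (opt_level O1).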
        if args.fp16:
            if not _has_apex:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            student_model, optimizer = amp.initialize(student_model, optimizer, opt_level='O1')

        if n_gpu > 1:
            student_model = torch.nn.DataParallel(student_model)
            teacher_model = torch.nn.DataParallel(teacher_model)

        # Prepare loss functions
        loss_mse = MSELoss()

        def soft_cross_entropy(predicts, targets):
            student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
            targets_prob = torch.nn.functional.softmax(targets, dim=-1)
            return (- targets_prob * student_likelihood).mean()

        # Train and evaluate
        global_step = 0
        best_dev_acc = 0.0
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

        for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0.
            tr_att_loss = 0.
            tr_rep_loss = 0.
            tr_cls_loss = 0.

            student_model.train()
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", ascii=True)):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, label_ids = batch
                if input_ids.size()[0] != args.train_batch_size:
                    continue

                att_loss = 0.
                rep_loss = 0.
                cls_loss = 0.

                student_logits, student_atts, student_reps = student_model(input_ids=input_ids,
                                                                           token_type_ids=segment_ids,
                                                                           attention_mask=input_mask,
                                                                           is_student=True)

                with torch.no_grad():
                    teacher_logits, teacher_atts, teacher_reps = teacher_model(input_ids=input_ids,
                                                                               token_type_ids=segment_ids,
                                                                               attention_mask=input_mask)

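                # Intermediate distillation matches attention maps and hidden
                # states; prediction distillation (--pred_distill) matches the
                # teacher's softened logits instead.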
                if not args.pred_distill:
                    teacher_layer_num = len(teacher_atts)
                    student_layer_num = len(student_atts)
                    assert teacher_layer_num % student_layer_num == 0
                    layers_per_block = int(teacher_layer_num / student_layer_num)
                    new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                                        for i in range(student_layer_num)]

                    for student_att, teacher_att in zip(student_atts, new_teacher_atts):
                        student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(device),
                                                  student_att)
                        teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device),
                                                  teacher_att)

                        tmp_loss = loss_mse(student_att, teacher_att)
                        att_loss += tmp_loss

                    new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)]
                    new_student_reps = student_reps
                    for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
                        tmp_loss = loss_mse(student_rep, teacher_rep)
                        rep_loss += tmp_loss

                    loss = rep_loss + att_loss
                    tr_att_loss += att_loss.item()
                    tr_rep_loss += rep_loss.item()
                else:
                    cls_loss = soft_cross_entropy(student_logits / args.temperature, teacher_logits / args.temperature)

                    loss = cls_loss
                    tr_cls_loss += cls_loss.item()

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += label_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)

                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if (global_step + 1) % args.eval_step == 0:
                    logger.info("***** Running evaluation *****")
                    logger.info("  Epoch = {} iter {} step".format(epoch_, global_step))
                    logger.info("  Num examples = %d", len(eval_examples))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    student_model.eval()

                    loss = tr_loss / (step + 1)
                    cls_loss = tr_cls_loss / (step + 1)
                    att_loss = tr_att_loss / (step + 1)
                    rep_loss = tr_rep_loss / (step + 1)

                    result = {}
                    if args.pred_distill:
                        result = do_eval(student_model, task_name, eval_dataloader,
                                         device, eval_labels, num_labels)
                    result['global_step'] = global_step
                    result['cls_loss'] = cls_loss
                    result['att_loss'] = att_loss
                    result['rep_loss'] = rep_loss
                    result['loss'] = loss

                    wandb.log(result, step=global_step)

                    result_to_file(result, output_eval_file)

                    if not args.pred_distill:
                        save_model = True
                    else:
                        save_model = False

                        if result['acc'] > best_dev_acc:
                            best_dev_acc = result['acc']
                            save_model = True

                    if save_model:
                        logger.info("***** Save model *****")

                        model_to_save = student_model.module if hasattr(student_model, 'module') else student_model

                        model_name = WEIGHTS_NAME
                        # if not args.pred_distill:
                        #     model_name = "step_{}_{}".format(global_step, WEIGHTS_NAME)
                        output_model_file = os.path.join(args.output_dir, model_name)
                        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

                        torch.save(model_to_save.state_dict(), output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)

                        if oncloud:
                            logging.info(mox.file.list_directory(args.output_dir, recursive=True))
                            logging.info(mox.file.list_directory('.', recursive=True))
                            mox.file.copy_parallel(args.output_dir, args.data_url)
                            mox.file.copy_parallel('.', args.data_url)

                    student_model.train()
Exemple #9
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--model",
                        default=None,
                        type=str,
                        required=True,
                        help="The model dir.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_eval",
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.06,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )

    parser.add_argument('--weight_bit',
                        type=int,
                        default=4,
                        help="Number of bits for weight.")
    parser.add_argument('--quant_group_number',
                        type=int,
                        default=1,
                        help="Number of quantization groups.")
    parser.add_argument('--activation_bit',
                        type=int,
                        default=8,
                        help="Number of bits for activation.")
    # added arguments
    parser.add_argument('--aug_train',
                        type=str,
                        default='none',
                        help="Whether to train with augmented data.")
    parser.add_argument('--eval_step', type=int, default=50)
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)
    parser.add_argument('--train_name', type=str, default="")
    parser.add_argument('--val_name', type=str, default="")

    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        #"sst-2": "regression",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification"
    }

    # intermediate distillation default parameters
    default_params = {
        "cola": {
            "num_train_epochs": 10,
            "max_seq_length": 64,
            "learning_rate": 2e-5,
            "train_batch_size": 32
        },
        "sst-2": {
            "num_train_epochs": 10,
            "max_seq_length": 64,
            "learning_rate": 2e-5,
            "train_batch_size": 32
        },
        "mnli": {
            "num_train_epochs": 10,
            "max_seq_length": 128,
            "learning_rate": 1e-5,
            "train_batch_size": 32
        },
        "mrpc": {
            "num_train_epochs": 10,
            "max_seq_length": 128,
            "learning_rate": 1e-5,
            "train_batch_size": 32
        },
        "sts-b": {
            "num_train_epochs": 10,
            "max_seq_length": 128,
            "learning_rate": 2e-5,
            "train_batch_size": 16
        },
        "qqp": {
            "num_train_epochs": 10,
            "max_seq_length": 128,
            "learning_rate": 1e-5,
            "train_batch_size": 32
        },
        "qnli": {
            "num_train_epochs": 10,
            "max_seq_length": 128,
            "learning_rate": 1e-5,
            "train_batch_size": 32
        },
        "rte": {
            "num_train_epochs": 10,
            "max_seq_length": 128,
            "learning_rate": 2e-5,
            "train_batch_size": 16
        }
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        # raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
        if os.path.exists(os.path.join(args.output_dir, "eval_results.txt")):
            os.remove(os.path.join(args.output_dir, "eval_results.txt"))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if not args.num_train_epochs:
        args.num_train_epochs = default_params[task_name]["num_train_epochs"]
    if not args.learning_rate:
        args.learning_rate = default_params[task_name]["learning_rate"]
    if not args.train_batch_size:
        args.train_batch_size = default_params[task_name]["train_batch_size"]
    if not args.max_seq_length:
        args.max_seq_length = default_params[task_name]["max_seq_length"]

    # print(task_name in default_params, args.num_train_epochs, args.max_seq_length)
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.model,
                                              do_lower_case=args.do_lower_case)

    if not args.do_eval:
        if args.aug_train == 'none':
            train_examples = processor.get_train_examples(
                args.data_dir, args.train_name)
        else:
            train_examples = processor.get_aug_examples(
                args.data_dir, args.aug_train)
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))

        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs

        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, output_mode)
        # train_features = old_convert_examples_to_features(train_examples, label_list,
        #                                               args.max_seq_length, tokenizer, output_mode)
        train_data, _ = get_tensor_data(output_mode, train_features)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

    eval_examples = processor.get_dev_examples(args.data_dir, args.val_name)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer, output_mode)
    # eval_features = old_convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # load config file from here
    quant_config = BertConfig.from_json_file("config/new_example_config.json")

    # change config if specified in command
    if "quant_group_number" in quant_config.__dict__:
        quant_config.__dict__["quant_group_number"] = args.quant_group_number
    for item in quant_config.__dict__:
        if "layer_bits" in item:
            for b_item in quant_config.__dict__[item]:
                quant_config.__dict__[item][b_item] = args.weight_bit
        elif "embed_bits" in item:
            for b_item in quant_config.__dict__[item]:
                quant_config.__dict__[item][b_item] = args.weight_bit
        elif "activation_bits" in item:
            quant_config.__dict__[item] = args.activation_bit

    model = QBertForSequenceClassification.from_pretrained(
        args.model, num_labels=num_labels, quant_config=quant_config)
    model.to(device)
    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        model.eval()
        result = do_eval(model, task_name, eval_dataloader, device,
                         output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
    else:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        param_optimizer = list(model.named_parameters())
        size = 0
        for n, p in model.named_parameters():
            logger.info('n: {}'.format(n))
            size += p.nelement()

        logger.info('Total parameters: {}'.format(size))

        no_decay = ['bias', 'LayerNorm.weight']
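        # Group parameters so that biases and LayerNorm weights receive no weight decay.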
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=1e-8)
        scheduler = NewWarmupLinearSchedule(
            optimizer,
            warmup_steps=int(args.warmup_proportion *
                             num_train_optimization_steps),
            t_total=num_train_optimization_steps)
        # optimizer = BertAdam(
        #         optimizer_grouped_parameters,
        #         lr=args.learning_rate,
        #         warmup=args.warmup_proportion,
        #         t_total=num_train_optimization_steps)

        # Prepare loss functions
        loss_mse = MSELoss()

        def soft_cross_entropy(predicts, targets):
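            # Negative log-likelihood of the student's softmax under the given targets;
            # here the targets are used directly as a probability distribution.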
            student_likelihood = torch.nn.functional.log_softmax(predicts,
                                                                 dim=-1)
            # targets_prob = torch.nn.functional.softmax(targets, dim=-1)
            targets_prob = targets
            return (-targets_prob * student_likelihood).mean()

        # Train and evaluate
        global_step = 0
        best_dev_acc = -1
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

        for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0.
            tr_att_loss = 0.
            tr_rep_loss = 0.
            tr_cls_loss = 0.

            model.train()
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", ascii=True)):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
                if input_ids.size()[0] != args.train_batch_size:
                    continue

                att_loss = 0.
                rep_loss = 0.
                cls_loss = 0.

                student_logits, student_atts, student_reps = model(
                    input_ids, segment_ids, input_mask, is_student=True)

                # if output_mode == "classification":
                #     loss_fct = CrossEntropyLoss()
                #     cls_loss = loss_fct(student_logits.view(-1, num_labels), label_ids.view(-1))

                if output_mode == "classification":
                    # loss_fct = CrossEntropyLoss()
                    cls_loss = soft_cross_entropy(student_logits, label_ids)

                elif output_mode == "regression":
                    loss_mse = MSELoss()
                    cls_loss = loss_mse(student_logits.view(-1),
                                        label_ids.view(-1))

                loss = cls_loss
                tr_cls_loss += cls_loss.item()

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += label_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

            logger.info("***** Running evaluation *****")
            logger.info("  Epoch = {} iter {} step".format(
                epoch_, global_step))
            logger.info("  Num examples = %d", len(eval_examples))
            logger.info("  Batch size = %d", args.eval_batch_size)

            model.eval()

            loss = tr_loss / (step + 1)
            cls_loss = tr_cls_loss / (step + 1)
            att_loss = tr_att_loss / (step + 1)
            rep_loss = tr_rep_loss / (step + 1)

            result = {}

            result = do_eval(model, task_name, eval_dataloader, device,
                             output_mode, eval_labels, num_labels)
            result['global_step'] = global_step
            result['cls_loss'] = cls_loss
            result['att_loss'] = att_loss
            result['rep_loss'] = rep_loss
            result['loss'] = loss

            result_to_file(result, output_eval_file)

            save_model = True
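            # save_model is always True here, so a checkpoint is written after every
            # epoch; the metric checks below track best_dev_acc.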

            if task_name in acc_tasks and result['acc'] > best_dev_acc:
                best_dev_acc = result['acc']
                save_model = True

            if task_name in corr_tasks and result['corr'] > best_dev_acc:
                best_dev_acc = result['corr']
                save_model = True

            if task_name in mcc_tasks and result['mcc'] > best_dev_acc:
                best_dev_acc = result['mcc']
                save_model = True

            if save_model:
                logger.info("***** Save model *****")

                model_to_save = model.module if hasattr(model,
                                                        'module') else model

                model_name = WEIGHTS_NAME
                output_model_file = os.path.join(args.output_dir, model_name)
                output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

                torch.save(model_to_save.state_dict(), output_model_file)
                model_to_save.config.to_json_file(output_config_file)
                tokenizer.save_vocabulary(args.output_dir)

                # Test mnli-mm
                # if task_name == "mnli":
                #     task_name = "mnli"
                #     processor = processors[task_name]()
                #     if not os.path.exists(args.output_dir + '-MM'):
                #         os.makedirs(args.output_dir + '-MM')

                #     eval_examples = processor.get_dev_examples(args.data_dir)

                #     eval_features = convert_examples_to_features(
                #         eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
                #     eval_data, eval_labels = get_tensor_data(output_mode, eval_features)

                #     logger.info("***** Running mm evaluation *****")
                #     logger.info("  Num examples = %d", len(eval_examples))
                #     logger.info("  Batch size = %d", args.eval_batch_size)

                #     eval_sampler = SequentialSampler(eval_data)
                #     eval_dataloader = DataLoader(eval_data, sampler=eval_sampler,
                #                                  batch_size=args.eval_batch_size)

                #     result = do_eval(model, task_name, eval_dataloader,
                #                      device, output_mode, eval_labels, num_labels)

                #     result['global_step'] = global_step

                #     tmp_output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
                #     result_to_file(result, tmp_output_eval_file)

                #     task_name = 'mnli'

                # if oncloud:
                #     logging.info(mox.file.list_directory(args.output_dir, recursive=True))
                #     logging.info(mox.file.list_directory('.', recursive=True))
                #     mox.file.copy_parallel(args.output_dir, args.data_url)
                #     mox.file.copy_parallel('.', args.data_url)

            model.train()
Exemple #10
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default="data/MNLI",
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--teacher_model",
                        default="pretrained/checkpoint-31280/",
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model",
                        default="pretrained/generalbert",
                        type=str,
                        help="The student model dir.")
    parser.add_argument("--task_name",
                        default="MNLI",
                        type=str,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default="output",
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=384,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=128,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=5.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )

    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=float, default=0.1)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)

    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    # intermediate distillation default parameters
    default_params = {
        "cola": {
            "num_train_epochs": 50,
            "max_seq_length": 64
        },
        "mnli": {
            "num_train_epochs": 5,
            "max_seq_length": 128
        },
        "mrpc": {
            "num_train_epochs": 20,
            "max_seq_length": 128
        },
        "sst-2": {
            "num_train_epochs": 10,
            "max_seq_length": 64
        },
        "sts-b": {
            "num_train_epochs": 20,
            "max_seq_length": 128
        },
        "qqp": {
            "num_train_epochs": 5,
            "max_seq_length": 128
        },
        "qnli": {
            "num_train_epochs": 10,
            "max_seq_length": 128
        },
        "rte": {
            "num_train_epochs": 20,
            "max_seq_length": 128
        }
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    tb = SummaryWriter("./runs")

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]

    if not args.pred_distill and not args.do_eval:
        if task_name in default_params:
            args.num_train_epochs = default_params[task_name][
                "num_train_epochs"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=args.do_lower_case)

    if not args.do_eval:
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))

        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

        train_data, _ = get_tensor_data(args, task_name, tokenizer, False,
                                        args.aug_train)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        num_train_optimization_steps = int(
            len(train_dataloader) /
            args.gradient_accumulation_steps) * args.num_train_epochs

    eval_data, eval_labels = get_tensor_data(args, task_name, tokenizer, True,
                                             False)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    if not args.do_eval:
        teacher_model = TinyBertForSequenceClassification.from_pretrained(
            args.teacher_model, num_labels=num_labels)
        teacher_model.to(device)

    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.student_model, num_labels=num_labels)
    student_model.to(device)
    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_data))
        logger.info("  Batch size = %d", args.eval_batch_size)

        student_model.eval()
        result = do_eval(student_model, task_name, eval_dataloader, device,
                         output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
    else:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        if n_gpu > 1:
            student_model = torch.nn.DataParallel(student_model)
            teacher_model = torch.nn.DataParallel(teacher_model)
        # Prepare optimizer
        param_optimizer = list(student_model.named_parameters())
        size = 0
        for n, p in student_model.named_parameters():
            logger.info('n: {}'.format(n))
            size += p.nelement()

        logger.info('Total parameters: {}'.format(size))
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        schedule = 'warmup_linear'
        if not args.pred_distill:
            schedule = 'none'
        optimizer = BertAdam(optimizer_grouped_parameters,
                             schedule=schedule,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)
        # Prepare loss functions
        loss_mse = MSELoss()

        def soft_cross_entropy(predicts, targets):
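            # Distillation loss: cross-entropy between the teacher's softened
            # probabilities and the student's log-probabilities.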
            student_likelihood = torch.nn.functional.log_softmax(predicts,
                                                                 dim=-1)
            targets_prob = torch.nn.functional.softmax(targets, dim=-1)
            return (-targets_prob * student_likelihood).mean()

        # Train and evaluate
        global_step = 0
        best_dev_acc = 0.0
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

        for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0.
            tr_att_loss = 0.
            tr_rep_loss = 0.
            tr_cls_loss = 0.

            student_model.train()
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration", ascii=True)):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
                if input_ids.size()[0] != args.train_batch_size:
                    continue

                att_loss = 0.
                rep_loss = 0.
                cls_loss = 0.

                student_logits, student_atts, student_reps = student_model(
                    input_ids, segment_ids, input_mask, is_student=True)

                with torch.no_grad():
                    teacher_logits, teacher_atts, teacher_reps = teacher_model(
                        input_ids, segment_ids, input_mask)

                if not args.pred_distill:
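                    # Intermediate-layer distillation: match the student's attention maps
                    # and hidden states to every layers_per_block-th teacher layer.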
                    teacher_layer_num = len(teacher_atts)
                    student_layer_num = len(student_atts)
                    assert teacher_layer_num % student_layer_num == 0
                    layers_per_block = int(teacher_layer_num /
                                           student_layer_num)
                    new_teacher_atts = [
                        teacher_atts[i * layers_per_block + layers_per_block -
                                     1] for i in range(student_layer_num)
                    ]

                    for student_att, teacher_att in zip(
                            student_atts, new_teacher_atts):
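                        # Zero out masked positions (scores <= -1e2) so they do not
                        # dominate the MSE between student and teacher attention maps.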
                        student_att = torch.where(
                            student_att <= -1e2,
                            torch.zeros_like(student_att).to(device),
                            student_att)
                        teacher_att = torch.where(
                            teacher_att <= -1e2,
                            torch.zeros_like(teacher_att).to(device),
                            teacher_att)

                        tmp_loss = loss_mse(student_att, teacher_att)
                        att_loss += tmp_loss

                    new_teacher_reps = [
                        teacher_reps[i * layers_per_block]
                        for i in range(student_layer_num + 1)
                    ]
                    new_student_reps = student_reps
                    for student_rep, teacher_rep in zip(
                            new_student_reps, new_teacher_reps):
                        tmp_loss = loss_mse(student_rep, teacher_rep)
                        rep_loss += tmp_loss

                    loss = rep_loss + att_loss
                    tr_att_loss += att_loss.item()
                    tr_rep_loss += rep_loss.item()
                else:
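                    # Prediction-layer distillation: match temperature-softened logits for
                    # classification, or fall back to MSE against the labels for regression.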
                    if output_mode == "classification":
                        cls_loss = soft_cross_entropy(
                            student_logits / args.temperature,
                            teacher_logits / args.temperature)
                    elif output_mode == "regression":
                        loss_mse = MSELoss()
                        cls_loss = loss_mse(student_logits.view(-1),
                                            label_ids.view(-1))

                    loss = cls_loss
                    tr_cls_loss += cls_loss.item()

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()
                tb.add_scalar("loss", loss.item(), global_step)
                tr_loss += loss.item()
                nb_tr_examples += label_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if (global_step + 1) % int(
                        args.eval_step * num_train_optimization_steps) == 0:
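                    # Evaluate (and possibly save) every eval_step fraction of the total
                    # number of optimization steps.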
                    logger.info("***** Running evaluation *****")
                    logger.info("  Epoch = {} iter {} step".format(
                        epoch_, global_step))
                    logger.info("  Num examples = %d", len(eval_data))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    student_model.eval()

                    loss = tr_loss / (step + 1)
                    cls_loss = tr_cls_loss / (step + 1)
                    att_loss = tr_att_loss / (step + 1)
                    rep_loss = tr_rep_loss / (step + 1)

                    result = {}
                    if args.pred_distill:
                        result = do_eval(student_model, task_name,
                                         eval_dataloader, device, output_mode,
                                         eval_labels, num_labels)
                    result['global_step'] = global_step
                    result['cls_loss'] = cls_loss
                    result['att_loss'] = att_loss
                    result['rep_loss'] = rep_loss
                    result['loss'] = loss

                    result_to_file(result, output_eval_file)

                    if not args.pred_distill:
                        save_model = True
                    else:
                        save_model = False

                        if task_name in acc_tasks and result[
                                'acc'] > best_dev_acc:
                            best_dev_acc = result['acc']
                            save_model = True

                        if task_name in corr_tasks and result[
                                'corr'] > best_dev_acc:
                            best_dev_acc = result['corr']
                            save_model = True

                        if task_name in mcc_tasks and result[
                                'mcc'] > best_dev_acc:
                            best_dev_acc = result['mcc']
                            save_model = True

                    if save_model:
                        logger.info("***** Save model *****")

                        model_to_save = student_model.module if hasattr(
                            student_model, 'module') else student_model

                        model_name = WEIGHTS_NAME
                        # if not args.pred_distill:
                        #     model_name = "step_{}_{}".format(global_step, WEIGHTS_NAME)
                        output_model_file = os.path.join(
                            args.output_dir, model_name)
                        output_config_file = os.path.join(
                            args.output_dir, CONFIG_NAME)

                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)

                        # Test mnli-mm
                        if args.pred_distill and task_name == "mnli":
                            task_name = "mnli-mm"
                            if not os.path.exists(args.output_dir + '-MM'):
                                os.makedirs(args.output_dir + '-MM')

                            eval_data, eval_labels = get_tensor_data(
                                args, task_name, tokenizer, True, False)

                            eval_sampler = SequentialSampler(eval_data)
                            eval_dataloader = DataLoader(
                                eval_data,
                                sampler=eval_sampler,
                                batch_size=args.eval_batch_size)
                            logger.info("***** Running mm evaluation *****")
                            logger.info("  Num examples = %d", len(eval_data))
                            logger.info("  Batch size = %d",
                                        args.eval_batch_size)

                            result = do_eval(student_model, task_name,
                                             eval_dataloader, device,
                                             output_mode, eval_labels,
                                             num_labels)

                            result['global_step'] = global_step

                            tmp_output_eval_file = os.path.join(
                                args.output_dir + '-MM', "eval_results.txt")
                            result_to_file(result, tmp_output_eval_file)

                            task_name = 'mnli'

                    student_model.train()
Exemple #11
def main():
    parser = ArgumentParser()
    parser.add_argument(
        '--pregenerated_data',
        type=str,
        required=True,
        default='/nas/hebin/data/english-exp/books_wiki_tokens_ngrams')
    parser.add_argument('--s3_output_dir', type=str, default='huawei_yun')
    parser.add_argument('--student_model',
                        type=str,
                        default='8layer_bert',
                        required=True)
    parser.add_argument('--teacher_model', type=str, default='electra_base')
    parser.add_argument('--cache_dir', type=str, default='/cache', help='')

    parser.add_argument("--epochs",
                        type=int,
                        default=2,
                        help="Number of epochs to train for")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=1e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--max_seq_length", type=int, default=512)

    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--scratch',
                        action='store_true',
                        help="Whether to train from scratch")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )
    parser.add_argument('--debug',
                        action='store_true',
                        help="Whether to debug")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )

    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument("--already_trained_epoch", default=0, type=int)
    parser.add_argument(
        "--masked_lm_prob",
        type=float,
        default=0.0,
        help="Probability of masking each token for the LM task")
    parser.add_argument(
        "--max_predictions_per_seq",
        type=int,
        default=77,
        help="Maximum number of tokens to mask in each sequence")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--logging_steps",
                        type=int,
                        default=500,
                        help="Log every X updates steps.")
    parser.add_argument("--warmup_steps",
                        default=10000,
                        type=int,
                        help="Linear warmup over warmup_steps.")

    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")

    parser.add_argument("--num_workers",
                        type=int,
                        default=4,
                        help="num_workers.")
    parser.add_argument("--continue_index", type=int, default=0, help="")
    parser.add_argument("--threads",
                        type=int,
                        default=27,
                        help="Number of threads to preprocess input data")

    # Search space for the sub-BERT architecture
    parser.add_argument('--layer_num_space',
                        nargs='+',
                        type=int,
                        default=[1, 8])
    parser.add_argument('--hidden_size_space',
                        nargs='+',
                        type=int,
                        default=[128, 768])
    parser.add_argument('--qkv_size_space',
                        nargs='+',
                        type=int,
                        default=[180, 768])
    parser.add_argument('--intermediate_size_space',
                        nargs='+',
                        type=int,
                        default=[128, 3072])
    parser.add_argument('--head_num_space',
                        nargs='+',
                        type=int,
                        default=[1, 12])
    parser.add_argument('--sample_times_per_batch', type=int, default=1)
    parser.add_argument('--further_train', action='store_true')
    parser.add_argument('--mlm_loss', action='store_true')

    # Argument for Huawei yun
    parser.add_argument('--data_url', type=str, default='', help='s3 url')
    parser.add_argument("--train_url", type=str, default="", help="s3 url")

    args = parser.parse_args()

    assert (torch.cuda.is_available())
    device_count = torch.cuda.device_count()
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))

    # Call the init process
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port

    # Manually set the device ids.
    # if device_count > 0:
    # args.local_rank = args.rank % device_count
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    print('device_id: %s' % args.local_rank)
    print('device_count: %s, rank: %s, world_size: %s' %
          (device_count, args.rank, args.world_size))
    print(init_method)

    torch.distributed.init_process_group(backend='nccl',
                                         world_size=args.world_size,
                                         rank=args.rank,
                                         init_method=init_method)

    LOCAL_DIR = args.cache_dir
    if oncloud:
        assert mox.file.exists(LOCAL_DIR)

    if args.local_rank == 0 and oncloud:
        logging.info(
            mox.file.list_directory(args.pregenerated_data, recursive=True))
        logging.info(
            mox.file.list_directory(args.student_model, recursive=True))

    local_save_dir = os.path.join(LOCAL_DIR, 'output', 'superbert',
                                  'checkpoints')
    local_tsbd_dir = os.path.join(LOCAL_DIR, 'output', 'superbert',
                                  'tensorboard')
    save_name = '_'.join([
        'superbert',
        'epoch',
        str(args.epochs),
        'lr',
        str(args.learning_rate),
        'bsz',
        str(args.train_batch_size),
        'grad_accu',
        str(args.gradient_accumulation_steps),
        str(args.max_seq_length),
        'gpu',
        str(args.world_size),
    ])
    bash_save_dir = os.path.join(local_save_dir, save_name)
    bash_tsbd_dir = os.path.join(local_tsbd_dir, save_name)
    if args.local_rank == 0:
        if not os.path.exists(bash_save_dir):
            os.makedirs(bash_save_dir)
            logger.info(bash_save_dir + ' created!')
        if not os.path.exists(bash_tsbd_dir):
            os.makedirs(bash_tsbd_dir)
            logger.info(bash_tsbd_dir + ' created!')

    local_data_dir_tmp = '/cache/data/tmp/'
    local_data_dir = local_data_dir_tmp + save_name

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    args.tokenizer = BertTokenizer.from_pretrained(
        args.student_model, do_lower_case=args.do_lower_case)
    args.vocab_list = list(args.tokenizer.vocab.keys())

    config = BertConfig.from_pretrained(
        os.path.join(args.student_model, CONFIG_NAME))
    logger.info("Model config {}".format(config))

    if args.further_train:
        if args.mlm_loss:
            student_model = SuperBertForPreTraining.from_pretrained(
                args.student_model, config)
        else:
            student_model = SuperTinyBertForPreTraining.from_pretrained(
                args.student_model, config)
    else:
        if args.mlm_loss:
            student_model = SuperBertForPreTraining.from_scratch(
                args.student_model, config)
        else:
            student_model = SuperTinyBertForPreTraining.from_scratch(
                args.student_model, config)

    student_model.to(device)

    if not args.mlm_loss:
        teacher_model = BertModel.from_pretrained(args.teacher_model)
        teacher_model.to(device)

    # build arch space
    min_hidden_size, max_hidden_size = args.hidden_size_space
    min_ffn_size, max_ffn_size = args.intermediate_size_space
    min_qkv_size, max_qkv_size = args.qkv_size_space
    min_head_num, max_head_num = args.head_num_space

    hidden_step = 4
    ffn_step = 4
    qkv_step = 12
    head_step = 1

    number_hidden_step = int((max_hidden_size - min_hidden_size) / hidden_step)
    number_ffn_step = int((max_ffn_size - min_ffn_size) / ffn_step)
    number_qkv_step = int((max_qkv_size - min_qkv_size) / qkv_step)
    number_head_step = int((max_head_num - min_head_num) / head_step)
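    # Enumerate the candidate values of every architecture dimension using the step sizes above.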

    layer_numbers = list(
        range(args.layer_num_space[0], args.layer_num_space[1] + 1))
    hidden_sizes = [
        i * hidden_step + min_hidden_size
        for i in range(number_hidden_step + 1)
    ]
    ffn_sizes = [
        i * ffn_step + min_ffn_size for i in range(number_ffn_step + 1)
    ]
    qkv_sizes = [
        i * qkv_step + min_qkv_size for i in range(number_qkv_step + 1)
    ]
    head_numbers = [
        i * head_step + min_head_num for i in range(number_head_step + 1)
    ]

    ######
    if args.local_rank == 0:
        tb_writer = SummaryWriter(bash_tsbd_dir)

    global_step = 0
    step = 0
    tr_loss, tr_rep_loss, tr_att_loss = 0.0, 0.0, 0.0
    logging_loss, rep_logging_loss, att_logging_loss = 0.0, 0.0, 0.0
    end_time, start_time = 0, 0

    submodel_config = dict()
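    # With --further_train, pin the sampled sub-model configuration to the student
    # config's dimensions instead of sampling a new architecture each batch.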

    if args.further_train:
        submodel_config['sample_layer_num'] = config.num_hidden_layers
        submodel_config['sample_hidden_size'] = config.hidden_size
        submodel_config[
            'sample_intermediate_sizes'] = config.num_hidden_layers * [
                config.intermediate_size
            ]
        submodel_config[
            'sample_num_attention_heads'] = config.num_hidden_layers * [
                config.num_attention_heads
            ]
        submodel_config['sample_qkv_sizes'] = config.num_hidden_layers * [
            config.qkv_size
        ]

    for epoch in range(args.epochs):
        if epoch < args.continue_index:
            args.warmup_steps = 0
            continue

        args.local_data_dir = os.path.join(local_data_dir, str(epoch))
        if args.local_rank == 0:
            os.makedirs(args.local_data_dir)
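        # Wait until the per-epoch data directory exists (rank 0 creates it above),
        # then load the tokenized n-gram dataset for this epoch.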
        while 1:
            if os.path.exists(args.local_data_dir):
                epoch_dataset = load_doc_tokens_ngrams(args)
                break

        if args.local_rank == 0 and oncloud:
            logging.info('Dataset in epoch %s', epoch)
            logging.info(
                mox.file.list_directory(args.local_data_dir, recursive=True))

        train_sampler = DistributedSampler(epoch_dataset,
                                           num_replicas=1,
                                           rank=0)

        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        step_in_each_epoch = len(
            train_dataloader) // args.gradient_accumulation_steps
        num_train_optimization_steps = step_in_each_epoch * args.epochs
        logging.info("***** Running training *****")
        logging.info("  Num examples = %d",
                     len(epoch_dataset) * args.world_size)
        logger.info("  Num Epochs = %d", args.epochs)
        logging.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.train_batch_size * args.gradient_accumulation_steps *
            args.world_size)
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logging.info("  Num steps = %d", num_train_optimization_steps)

        if epoch == args.continue_index:
            # Prepare optimizer
            param_optimizer = list(student_model.named_parameters())
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]

            warm_up_ratio = args.warmup_steps / num_train_optimization_steps
            print('warm_up_ratio: {}'.format(warm_up_ratio))
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 e=args.adam_epsilon,
                                 schedule='warmup_linear',
                                 t_total=num_train_optimization_steps,
                                 warmup=warm_up_ratio)

            if args.fp16:
                try:
                    from apex import amp
                except ImportError:
                    raise ImportError(
                        "Please install apex from https://www.github.com/nvidia/apex"
                        " to use fp16 training.")
                student_model, optimizer = amp.initialize(
                    student_model,
                    optimizer,
                    opt_level=args.fp16_opt_level,
                    min_loss_scale=1)  #

            # apex
            student_model = DDP(
                student_model,
                message_size=10000000,
                gradient_predivide_factor=torch.distributed.get_world_size(),
                delay_allreduce=True)

            if not args.mlm_loss:
                teacher_model = DDP(teacher_model,
                                    message_size=10000000,
                                    gradient_predivide_factor=torch.
                                    distributed.get_world_size(),
                                    delay_allreduce=True)
                teacher_model.eval()

            logger.info('apex data paralleled!')

        from torch.nn import MSELoss
        loss_mse = MSELoss()

        student_model.train()
        for step_, batch in enumerate(train_dataloader):
            step += 1
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_masks, lm_label_ids = batch

            if not args.mlm_loss:
                teacher_last_rep, teacher_last_att = teacher_model(
                    input_ids, input_masks)
                teacher_last_att = torch.where(
                    teacher_last_att <= -1e2,
                    torch.zeros_like(teacher_last_att).to(device),
                    teacher_last_att)
                # detach() is not in-place; reassign so no gradients flow back into the teacher.
                teacher_last_rep = teacher_last_rep.detach()
                teacher_last_att = teacher_last_att.detach()

            for sample_idx in range(args.sample_times_per_batch):
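                # Each pass samples a sub-architecture (unless --further_train is set) and
                # back-propagates its distillation or MLM loss; gradients accumulate until
                # the shared optimizer step.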
                att_loss = 0.
                rep_loss = 0.
                rand_seed = int(global_step * args.world_size +
                                sample_idx)  # + args.rank % args.world_size)

                if not args.mlm_loss:
                    if not args.further_train:
                        submodel_config = sample_arch_4_kd(
                            layer_numbers,
                            hidden_sizes,
                            ffn_sizes,
                            qkv_sizes,
                            reset_rand_seed=True,
                            rand_seed=rand_seed)
                    # knowledge distillation
                    student_last_rep, student_last_att = student_model(
                        input_ids, submodel_config, attention_mask=input_masks)
                    student_last_att = torch.where(
                        student_last_att <= -1e2,
                        torch.zeros_like(student_last_att).to(device),
                        student_last_att)

                    att_loss += loss_mse(student_last_att, teacher_last_att)
                    rep_loss += loss_mse(student_last_rep, teacher_last_rep)
                    loss = att_loss + rep_loss

                    if args.gradient_accumulation_steps > 1:
                        rep_loss = rep_loss / args.gradient_accumulation_steps
                        att_loss = att_loss / args.gradient_accumulation_steps
                        loss = loss / args.gradient_accumulation_steps

                    tr_rep_loss += rep_loss.item()
                    tr_att_loss += att_loss.item()
                else:
                    if not args.further_train:
                        submodel_config = sample_arch_4_mlm(
                            layer_numbers,
                            hidden_sizes,
                            ffn_sizes,
                            head_numbers,
                            reset_rand_seed=True,
                            rand_seed=rand_seed)
                    loss = student_model(input_ids,
                                         submodel_config,
                                         attention_mask=input_masks,
                                         masked_lm_labels=lm_label_ids)

                tr_loss += loss.item()
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                else:
                    loss.backward(retain_graph=True)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(student_model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                if (step + 1) % (args.gradient_accumulation_steps * args.logging_steps) == 0 \
                        and args.local_rank < 2 or global_step < 100:
                    end_time = time.time()

                    if not args.mlm_loss:
                        logger.info(
                            'Epoch: %s, global_step: %s/%s, lr: %s, loss is %s; '
                            'rep_loss is %s; att_loss is %s; (%.2f sec)' %
                            (epoch, global_step + 1, step_in_each_epoch,
                             optimizer.get_lr()[0],
                             loss.item() * args.gradient_accumulation_steps,
                             rep_loss.item() *
                             args.gradient_accumulation_steps, att_loss.item()
                             * args.gradient_accumulation_steps,
                             end_time - start_time))
                    else:
                        logger.info(
                            'Epoch: %s, global_step: %s/%s, lr: %s, loss is %s; '
                            ' (%.2f sec)' %
                            (epoch, global_step + 1, step_in_each_epoch,
                             optimizer.get_lr()[0],
                             loss.item() * args.gradient_accumulation_steps,
                             end_time - start_time))
                    start_time = time.time()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0 and args.local_rank == 0:
                    tb_writer.add_scalar("lr",
                                         optimizer.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)

                    if not args.mlm_loss:
                        tb_writer.add_scalar("rep_loss",
                                             (tr_rep_loss - rep_logging_loss) /
                                             args.logging_steps, global_step)
                        tb_writer.add_scalar("att_loss",
                                             (tr_att_loss - att_logging_loss) /
                                             args.logging_steps, global_step)
                        rep_logging_loss = tr_rep_loss
                        att_logging_loss = tr_att_loss

                    logging_loss = tr_loss

        # Save a trained model
        if args.rank == 0:
            saving_path = bash_save_dir
            saving_path = Path(os.path.join(saving_path,
                                            "epoch_" + str(epoch)))

            if saving_path.is_dir() and list(saving_path.iterdir()):
                logging.warning(
                    f"Output directory ({ saving_path }) already exists and is not empty!"
                )
            saving_path.mkdir(parents=True, exist_ok=True)

            logging.info("** ** * Saving fine-tuned model ** ** * ")
            model_to_save = student_model.module if hasattr(student_model, 'module')\
                else student_model  # Only save the model it-self

            output_model_file = os.path.join(saving_path, WEIGHTS_NAME)
            output_config_file = os.path.join(saving_path, CONFIG_NAME)

            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            args.tokenizer.save_vocabulary(saving_path)

            torch.save(optimizer.state_dict(),
                       os.path.join(saving_path, "optimizer.pt"))
            logger.info("Saving optimizer and scheduler states to %s",
                        saving_path)

            # debug
            if oncloud:
                local_output_dir = os.path.join(LOCAL_DIR, 'output')
                logger.info(
                    mox.file.list_directory(local_output_dir, recursive=True))
                logger.info('s3_output_dir: ' + args.s3_output_dir)
                mox.file.copy_parallel(local_output_dir, args.s3_output_dir)

    if args.local_rank == 0:
        tb_writer.close()
Exemple #12
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--job_id", default='tmp', type=str, help='Jobid to save training logs')
    parser.add_argument("--data_dir",default=None,type=str,help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--teacher_model",default=None,type=str,help="The teacher model dir.")
    parser.add_argument("--student_model",default=None,type=str,help="The student model dir.")
    parser.add_argument("--output_dir",default='output',type=str,help="The output directory where the model predictions and checkpoints will be written.")

    # default params for SQuAD
    parser.add_argument('--version_2_with_negative', 
                        action='store_true')
    parser.add_argument("--max_seq_length",
                        default=384,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument('--null_score_diff_threshold',
                        type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")
    
    parser.add_argument("--batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay', '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--eval_step', type=int, default=200)
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    
    parser.add_argument('--do_eval', default=0, type=int)
    # distillation params
    parser.add_argument('--aug_train', action='store_true',
                        help="Whether using data augmentation or not")
    parser.add_argument('--kd_type', default='no_kd', choices=['no_kd', 'two_stage', 'logit_kd', 'joint_kd'],
                        help="choose one of the kd type")
    parser.add_argument('--distill_logit', action='store_true',
                        help="Whether using distillation over logits or not")
    parser.add_argument('--distill_rep_attn', action='store_true',
                        help="Whether using distillation over reps and attns or not")
    parser.add_argument('--temperature', type=float, default=1.)
    # quantization params
    parser.add_argument("--weight_bits", default=32, type=int, help="number of bits for weight")
    parser.add_argument("--weight_quant_method", default='twn', type=str,
                        choices=['twn', 'bwn', 'uniform', 'laq'],
                        help="weight quantization methods, can be bwn, twn, laq")
    parser.add_argument("--input_bits",  default=32, type=int,
                        help="number of bits for activation")
    parser.add_argument("--input_quant_method", default='uniform', type=str, choices=['uniform', 'lsq'],
                        help="weight quantization methods, can be bwn, twn, or symmetric quantization for default")

    parser.add_argument('--learnable_scaling', action='store_true', default=True)
    parser.add_argument("--ACT2FN", default='gelu', type=str,
                        help='activation fn for ffn-mid. A8 uses uq + gelu; A4 uses lsq + relu.')
    # training config
    parser.add_argument('--sym_quant_ffn_attn', action='store_true',
                        help='whether to use symmetric quantization for attention scores and FFN activations')  # default asym
    parser.add_argument('--sym_quant_qkvo', action='store_true', default=True,
                        help='whether to use symmetric quantization for Q/K/V and other linear layers.')  # default sym
    # clipping threshold config
    parser.add_argument('--clip_init_file', default='threshold_std.pkl', help='files to restore init clip values.')
    parser.add_argument('--clip_init_val', default=2.5, type=float, help='init value of clip_vals, default to (-2.5, +2.5).')
    parser.add_argument('--clip_lr', default=1e-4, type=float, help='Use a separate lr for clip_vals / stepsize')
    parser.add_argument('--clip_wd', default=0.0, type=float, help='weight decay for clip_vals / stepsize')

    # layerwise quantization config
    parser.add_argument('--embed_layerwise', default=False, type=lambda x: bool(int(x)))
    parser.add_argument('--weight_layerwise', default=True, type=lambda x: bool(int(x)))
    parser.add_argument('--input_layerwise', default=True, type=lambda x: bool(int(x)))

    ### splitting
    parser.add_argument('--split', action='store_true',
                        help='whether to conduct TWS splitting. NOTE this is only for training BinaryBERT')
    parser.add_argument('--is_binarybert', action='store_true',
                        help='whether to use binarybert modelling.')

    args = parser.parse_args()

    log_dir = os.path.join(args.output_dir, 'record_%s.log' % args.job_id)
    init_logging(log_dir)
    print_args(vars(args))

    # Prepare devices
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logging.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    tokenizer = BertTokenizer.from_pretrained(args.teacher_model, do_lower_case=True)
    config = BertConfig.from_pretrained(args.teacher_model)
    config.num_labels = 2

    student_config = copy.deepcopy(config)
    student_config.weight_bits = args.weight_bits
    student_config.input_bits = args.input_bits
    student_config.weight_quant_method = args.weight_quant_method
    student_config.input_quant_method = args.input_quant_method
    student_config.clip_init_val = args.clip_init_val
    student_config.learnable_scaling = args.learnable_scaling
    student_config.sym_quant_qkvo = args.sym_quant_qkvo
    student_config.sym_quant_ffn_attn = args.sym_quant_ffn_attn
    student_config.embed_layerwise = args.embed_layerwise
    student_config.weight_layerwise = args.weight_layerwise
    student_config.input_layerwise = args.input_layerwise
    student_config.hidden_act = args.ACT2FN
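    # Illustrative only: the 'twn' choice for --weight_quant_method refers to ternary
    # weight networks. A minimal sketch of that quantizer, following the usual TWN recipe
    # (threshold = 0.7 * mean|W|, scale = mean of the retained |W|); the repository's
    # actual quantizer may differ.
    def _ternarize_sketch(weight):
        delta = 0.7 * weight.abs().mean()                                # layerwise threshold
        mask = (weight.abs() > delta).float()                            # keep large-magnitude weights
        alpha = (weight.abs() * mask).sum() / mask.sum().clamp(min=1.0)  # per-layer scaling factor
        return alpha * torch.sign(weight) * mask                         # values in {-alpha, 0, +alpha}
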
    logging.info("***** Training data *****")
    input_file = 'train-v2.0.json' if args.version_2_with_negative else 'train-v1.1.json'
    input_file = os.path.join(args.data_dir,input_file)

    if os.path.exists(input_file+'.features.pkl'):
        logging.info("  loading from cache %s", input_file+'.features.pkl')
        train_features = pickle.load(open(input_file+'.features.pkl', 'rb'))
    else:
        _, train_examples = read_squad_examples(input_file=input_file, is_training=True, 
                                                version_2_with_negative=args.version_2_with_negative)
        train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
        pickle.dump(train_features, open(input_file+'.features.pkl','wb'))
        
    args.batch_size = args.batch_size // args.gradient_accumulation_steps
    num_train_optimization_steps = int(
        len(train_features) / args.batch_size / args.gradient_accumulation_steps) * args.num_train_epochs

    logging.info("  Num examples = %d", len(train_features))
    logging.info("  Num total steps = %d", num_train_optimization_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
    all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                all_start_positions, all_end_positions)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size)

    logging.info("***** Evaluation data *****")
    input_file = 'dev-v2.0.json' if args.version_2_with_negative else 'dev-v1.1.json'
    args.dev_file = os.path.join(args.data_dir,input_file)
    dev_dataset, eval_examples = read_squad_examples(
                        input_file=args.dev_file, is_training=False,
                        version_2_with_negative=args.version_2_with_negative)
    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)

    logging.info("  Num examples = %d", len(eval_features))

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)

    if not args.do_eval:
        from transformer.modeling_dynabert import BertForQuestionAnswering
        teacher_model = BertForQuestionAnswering.from_pretrained(args.teacher_model, config=config)
        teacher_model.to(device)
        if n_gpu > 1:
            teacher_model = torch.nn.DataParallel(teacher_model)
    else:
        teacher_model = None  # no teacher is needed for evaluation-only runs

    if args.split:
        # rename the checkpoint to restore
        split_model_dir = os.path.join(args.output_dir,'binary_model_init')
        if not os.path.exists(split_model_dir):
            os.mkdir(split_model_dir)
        # copy the json file, avoid over-writing
        source_model_dir = os.path.join(args.student_model, CONFIG_NAME)
        target_model_dir = os.path.join(split_model_dir, CONFIG_NAME)
        os.system('cp -v %s %s' % (source_model_dir, target_model_dir))

        # create the split model ckpt
        source_model_dir = os.path.join(args.student_model, WEIGHTS_NAME)
        target_model_dir = os.path.join(split_model_dir, WEIGHTS_NAME)
        target_model_dir = tws_split(source_model_dir, target_model_dir)
        args.student_model = split_model_dir  # over-write student_model dir
        print("transformed binary model stored at: {}".format(target_model_dir))

    if args.is_binarybert:
        from transformer.modeling_dynabert_binary import BertForQuestionAnswering
        student_model = BertForQuestionAnswering.from_pretrained(args.student_model, config=student_config)
    else:
        from transformer.modeling_dynabert_quant import BertForQuestionAnswering
        student_model = BertForQuestionAnswering.from_pretrained(args.student_model, config=student_config)
    student_model.to(device)
    if n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)

    learner = KDLearner(args, device, student_model, teacher_model, num_train_optimization_steps)

    if args.do_eval:
        """ evaluation """
        learner.eval(student_model, eval_dataloader, eval_features, eval_examples, dev_dataset)
        return 0

    """ perform training """
    if args.kd_type == 'joint_kd':
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = True
        learner.build()
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

    elif args.kd_type == 'logit_kd':
        # only perform the logits kd
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = False
        learner.build(lr=args.learning_rate)
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

    elif args.kd_type == 'two_stage':
        # stage 1: intermediate layer distillation
        learner.args.distill_logit = False
        learner.args.distill_rep_attn = True
        learner.build(lr=2.5*args.learning_rate)
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

        # stage 2: prediction layer distillation
        learner.student_model.load_state_dict(torch.load(os.path.join(learner.output_dir,WEIGHTS_NAME)))
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = False

        learner.build(lr=args.learning_rate)  # prepare the optimizer again.
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

    else:
        assert args.kd_type == 'no_kd'
        # NO kd training, vanilla cross entropy with hard label
        learner.build(lr=args.learning_rate)  # prepare the optimizer again.
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

    del learner
    return 0
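The --kd_type branches above only toggle which distillation signals KDLearner uses. A hedged illustration of the two signals (the actual loss weighting lives inside KDLearner and may differ; the function names here are assumptions):

import torch.nn.functional as F

def soft_cross_entropy(student_logits, teacher_logits, temperature=1.0):
    # Logit distillation: match the teacher's temperature-softened class distribution.
    log_p = F.log_softmax(student_logits / temperature, dim=-1)
    q = F.softmax(teacher_logits / temperature, dim=-1)
    return -(q * log_p).sum(dim=-1).mean()

def rep_attn_loss(student_reps, teacher_reps, student_attns, teacher_attns):
    # Intermediate-layer distillation: MSE over hidden states and attention maps.
    loss = sum(F.mse_loss(s, t) for s, t in zip(student_reps, teacher_reps))
    return loss + sum(F.mse_loss(s, t) for s, t in zip(student_attns, teacher_attns))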
Exemple #13
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        help="The student model dir.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)

    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification"
    }

    # intermediate distillation default parameters
    default_params = {
        "cola": {
            "num_train_epochs": 50,
            "max_seq_length": 64,
            'train_batch_size': 32
        },
        "mnli": {
            "num_train_epochs": 5,
            "max_seq_length": 128,
            'train_batch_size': 64
        },
        "mrpc": {
            "num_train_epochs": 20,
            "max_seq_length": 128,
            'train_batch_size': 32
        },
        "sst-2": {
            "num_train_epochs": 10,
            "max_seq_length": 64,
            'train_batch_size': 32
        },
        "sts-b": {
            "num_train_epochs": 20,
            "max_seq_length": 128,
            'train_batch_size': 32
        },
        "qqp": {
            "num_train_epochs": 5,
            "max_seq_length": 128,
            'train_batch_size': 64
        },
        "qnli": {
            "num_train_epochs": 10,
            "max_seq_length": 128,
            'train_batch_size': 64
        },
        "rte": {
            "num_train_epochs": 20,
            "max_seq_length": 128,
            'train_batch_size': 32
        }
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    task_name = args.task_name.lower()

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=args.do_lower_case)

    student_model = PrunBertForSequenceClassification.from_pretrained(
        args.student_model, num_labels=num_labels)
    student_model.to(device)
    student_model.eval()

    if args.do_eval:
        eval_dataloader, num_eval_examples, eval_labels = build_dataloader(
            'dev', args, processor, label_list, tokenizer, output_mode)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", num_eval_examples)
        logger.info("  Batch size = %d", args.eval_batch_size)
        result = do_eval(student_model, task_name, eval_dataloader, device,
                         output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        result_to_file(result, output_eval_file)

        if task_name == "mnli":
            task_name = "mnli-mm"
            processor = processors[task_name]()
            eval_dataloader, num_eval_examples, eval_labels = build_dataloader(
                'dev', args, processor, label_list, tokenizer, output_mode)
            logger.info("***** Running mm evaluation *****")
            logger.info("  Num examples = %d", num_eval_examples)
            logger.info("  Batch size = %d", args.eval_batch_size)
            result = do_eval(student_model, task_name, eval_dataloader, device,
                             output_mode, eval_labels, num_labels)
            output_eval_file = os.path.join(args.output_dir,
                                            "eval_results-mm.txt")
            result_to_file(result, output_eval_file)
            task_name = "mnli"

    if args.do_predict:
        processor = processors[task_name]()
        test_dataloader, num_test_examples, _ = build_dataloader(
            'test', args, processor, label_list, tokenizer, output_mode)
        logger.info("***** Running prediction *****")
        logger.info("  Num examples = %d", num_test_examples)
        logger.info("  Batch size = %d", args.eval_batch_size)
        predictions = do_predict(student_model, task_name, test_dataloader,
                                 device, output_mode, num_labels)
        label_list = processor.get_labels()
        write_predictions(predictions, args, task_name, output_mode,
                          label_list)

        if task_name == "mnli":
            task_name = "mnli-mm"
            processor = processors[task_name]()
            test_dataloader, num_test_examples, _ = build_dataloader(
                'test', args, processor, label_list, tokenizer, output_mode)

            logger.info("***** Running mm prediction *****")
            logger.info("  Num examples = %d", num_test_examples)
            logger.info("  Batch size = %d", args.eval_batch_size)
            predictions = do_predict(student_model, task_name, test_dataloader,
                                     device, output_mode, num_labels)
            write_predictions(predictions, args, task_name, output_mode,
                              label_list)
            task_name = 'mnli'
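The do_eval and do_predict helpers are defined elsewhere in the repository. A hedged sketch of the metrics implied by the acc_tasks / corr_tasks / mcc_tasks lists above (the repository's own evaluation code may compute additional statistics):

import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import matthews_corrcoef

def simple_glue_metric(task_name, preds, labels):
    preds, labels = np.asarray(preds), np.asarray(labels)
    if task_name == "cola":                        # mcc_tasks
        return {"mcc": matthews_corrcoef(labels, preds)}
    if task_name == "sts-b":                       # corr_tasks (regression)
        return {"pearson": pearsonr(preds, labels)[0]}
    return {"acc": float((preds == labels).mean())}    # acc_tasks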
Exemple #14
0
def main():
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--bert_model", type=str, required=True)
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Reduce memory usage for large datasets by keeping data on disc rather than in memory"
    )

    parser.add_argument("--num_workers",
                        type=int,
                        default=1,
                        help="The number of workers to use to write the files")

    # add 1: for Huawei cloud (yun).
    parser.add_argument("--data_url", type=str, default="", help="s3 url")
    parser.add_argument("--train_url", type=str, default="", help="s3 url")
    parser.add_argument("--init_method", default='', type=str)

    args = parser.parse_args()

    # add 2: for Huawei cloud (yun).
    if oncloud:
        os.environ['DLS_LOCAL_CACHE_PATH'] = "/cache"
        local_data_dir = os.environ['DLS_LOCAL_CACHE_PATH']
        assert mox.file.exists(local_data_dir)
        logging.info("local disk: " + local_data_dir)
        logging.info("copy data from s3 to local")
        logging.info(mox.file.list_directory(args.data_url, recursive=True))
        mox.file.copy_parallel(args.data_url, local_data_dir)
        logging.info("copy finish...........")

        args.train_corpus = Path(
            os.path.join(local_data_dir, args.train_corpus))
        args.bert_model = os.path.join(local_data_dir, args.bert_model)

        args.train_url = os.path.join(args.train_url, args.output_dir)
        args.output_dir = Path(os.path.join(local_data_dir, args.output_dir))

    if args.num_workers > 1 and args.reduce_memory:
        raise ValueError("Cannot use multiple workers while reducing memory")

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    doc_num = 0
    with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
        with args.train_corpus.open() as f:
            doc = []
            for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
                line = line.strip()
                if line == "":
                    docs.add_document(doc)
                    doc = []
                    doc_num += 1
                    if doc_num % 100 == 0:
                        logger.info('loaded {} docs!'.format(doc_num))
                else:
                    tokens = tokenizer.tokenize(line)
                    doc.append(tokens)
            if doc:
                docs.add_document(
                    doc
                )  # If the last doc didn't end on a newline, make sure it still gets added
        if len(docs) <= 1:
            exit(
                "ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
                "ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
                "indicate breaks between documents in your input file. If your dataset does not contain multiple "
                "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
                "sections or paragraphs.")

        args.output_dir.mkdir(exist_ok=True)
        file_num = 28

        fouts = []
        for i in range(file_num):
            file_name = os.path.join(
                str(args.output_dir),
                'train_doc_tokens_ngrams_{}.json'.format(i))
            fouts.append(open(file_name, 'w'))

        cnt = 0
        for doc_idx in trange(len(docs), desc="Document"):
            document = docs[doc_idx]
            i = 0
            tokens = []
            while i < len(document):
                segment = document[i]
                if len(tokens) + len(segment) > args.max_seq_len:
                    instance = {"tokens": tokens}

                    file_idx = cnt % file_num
                    fouts[file_idx].write(json.dumps(instance) + '\n')

                    cnt += 1
                    if cnt % 100000 == 0:
                        logger.info('loaded {} examples!'.format(cnt))

                    if cnt <= 10:
                        logger.info('instance: {}'.format(instance))

                    tokens = []
                    tokens += segment
                else:
                    tokens += segment

                i += 1

            if tokens:
                instance = {"tokens": tokens}
                file_idx = cnt % file_num
                fouts[file_idx].write(json.dumps(instance) + '\n')

        for fout in fouts:
            fout.close()

        if oncloud:
            logging.info(
                mox.file.list_directory(str(args.output_dir), recursive=True))
            mox.file.copy_parallel(str(args.output_dir), args.train_url)
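The shards written above contain one JSON object per line with a single "tokens" field. A hedged reader for downstream pre-training code (the function name is an assumption):

import json
import os

def iter_token_instances(output_dir, file_num=28):
    # Yield the token lists written by the sharding loop above, one instance at a time.
    for i in range(file_num):
        path = os.path.join(str(output_dir), 'train_doc_tokens_ngrams_{}.json'.format(i))
        with open(path) as f:
            for line in f:
                yield json.loads(line)["tokens"]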
Exemple #15
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--job_id",
                        default='tmp',
                        type=str,
                        help='jobid to save training logs')
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        help="The root dir of glue data")
    parser.add_argument("--teacher_model",
                        default='',
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model",
                        default='',
                        type=str,
                        help="The student model dir.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        help="The name of the glue task to train.")
    parser.add_argument(
        "--output_dir",
        default='output',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--max_seq_length",
        default=None,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded."
    )

    parser.add_argument("--batch_size",
                        default=None,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=0.01,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=None,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")

    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--do_eval", action='store_true')
    parser.add_argument('--eval_step', type=int, default=100)

    # distillation params
    parser.add_argument('--aug_train',
                        action='store_true',
                        help="Whether using data augmentation or not")
    parser.add_argument('--kd_type',
                        default='no_kd',
                        choices=['no_kd', 'two_stage', 'logit_kd', 'joint_kd'],
                        help="choose one of the kd type")
    parser.add_argument('--distill_logit',
                        action='store_true',
                        help="Whether using distillation over logits or not")
    parser.add_argument(
        '--distill_rep_attn',
        action='store_true',
        help="Whether using distillation over reps and attns or not")
    parser.add_argument('--temperature', type=float, default=1.)

    # quantization params
    parser.add_argument("--weight_bits",
                        default=32,
                        type=int,
                        help="number of bits for weight")
    parser.add_argument(
        "--weight_quant_method",
        default='twn',
        type=str,
        choices=['twn', 'bwn', 'uniform', 'laq'],
        help="weight quantization method, can be twn, bwn, uniform, or laq")
    parser.add_argument("--input_bits",
                        default=32,
                        type=int,
                        help="number of bits for activation")
    parser.add_argument(
        "--input_quant_method",
        default='uniform',
        type=str,
        choices=['uniform', 'lsq'],
        help="activation quantization method, can be uniform or lsq")

    parser.add_argument('--learnable_scaling',
                        action='store_true',
                        default=True)
    parser.add_argument(
        "--ACT2FN",
        default='gelu',
        type=str,
        help='activation fn for the FFN intermediate layer. A8 (8-bit activations) uses uniform quant + gelu; A4 uses lsq + relu.'
    )

    # training config
    parser.add_argument(
        '--sym_quant_ffn_attn',
        action='store_true',
        help='whether to use symmetric quantization for attention scores and FFN activations'
    )  # default asym
    parser.add_argument(
        '--sym_quant_qkvo',
        action='store_true',
        default=True,
        help='whether to use symmetric quantization for Q/K/V and other linear layers.')  # default sym

    # hypers clipping threshold
    # parser.add_argument('--restore_clip', action='store_true',
    #                     help='if true, restore the last step model from rep/attn kd for two stage kd')
    parser.add_argument('--clip_init_file',
                        default='threshold_std.pkl',
                        help='files to restore init clip values.')
    parser.add_argument(
        '--clip_init_val',
        default=2.5,
        type=float,
        help='init value of clip_vals, default to (-2.5, +2.5).')
    parser.add_argument('--clip_lr',
                        default=1e-4,
                        type=float,
                        help='Use a separate lr for clip_vals / stepsize')
    parser.add_argument('--clip_wd',
                        default=0.0,
                        type=float,
                        help='weight decay for clip_vals / stepsize')

    # layerwise quantization config
    parser.add_argument('--embed_layerwise',
                        default=False,
                        type=lambda x: bool(int(x)))
    parser.add_argument('--weight_layerwise',
                        default=True,
                        type=lambda x: bool(int(x)))
    parser.add_argument('--input_layerwise',
                        default=True,
                        type=lambda x: bool(int(x)))

    ### splitting
    parser.add_argument(
        '--split',
        action='store_true',
        help=
        'whether to conduct TWS splitting. NOTE this is only for training BinaryBERT'
    )
    parser.add_argument('--is_binarybert',
                        action='store_true',
                        help='whether to use binarybert modelling.')

    args = parser.parse_args()
    args.do_lower_case = True

    log_dir = os.path.join(args.output_dir, 'record_%s.log' % args.job_id)
    init_logging(log_dir)

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logging.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    task_name = args.task_name.lower()

    # restore the default setting if they are None
    if args.batch_size is None:
        if task_name in default_params:
            args.batch_size = default_params[task_name]["batch_size"]
            args.batch_size = int(args.batch_size * n_gpu)
    if args.max_seq_length is None:
        if task_name in default_params:
            args.max_seq_length = default_params[task_name]["max_seq_length"]
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)
    print_args(vars(args))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.teacher_model,
                                              do_lower_case=args.do_lower_case)
    config = BertConfig.from_pretrained(args.teacher_model)
    config.num_labels = num_labels

    student_config = copy.deepcopy(config)
    student_config.weight_bits = args.weight_bits
    student_config.input_bits = args.input_bits
    student_config.weight_quant_method = args.weight_quant_method
    student_config.input_quant_method = args.input_quant_method
    student_config.clip_init_val = args.clip_init_val
    student_config.learnable_scaling = args.learnable_scaling
    student_config.sym_quant_qkvo = args.sym_quant_qkvo
    student_config.sym_quant_ffn_attn = args.sym_quant_ffn_attn
    student_config.embed_layerwise = args.embed_layerwise
    student_config.weight_layerwise = args.weight_layerwise
    student_config.input_layerwise = args.input_layerwise
    student_config.hidden_act = args.ACT2FN

    num_train_optimization_steps = 0
    if not args.do_eval:
        if args.aug_train:
            train_examples = processor.get_aug_examples(args.data_dir)
        else:
            train_examples = processor.get_train_examples(args.data_dir)
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))

        args.batch_size = args.batch_size // args.gradient_accumulation_steps

        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, output_mode)
        train_data, _ = get_tensor_data(output_mode, train_features)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size)

        num_train_optimization_steps = int(
            len(train_features) / args.batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs

    eval_examples = processor.get_dev_examples(args.data_dir)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer, output_mode)
    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.batch_size)
    if task_name == "mnli":
        processor = processors["mnli-mm"]()
        if not os.path.exists(args.output_dir + '-MM'):
            os.makedirs(args.output_dir + '-MM')

        mm_eval_examples = processor.get_dev_examples(args.data_dir)
        mm_eval_features = convert_examples_to_features(
            mm_eval_examples, label_list, args.max_seq_length, tokenizer,
            output_mode)
        mm_eval_data, mm_eval_labels = get_tensor_data(output_mode,
                                                       mm_eval_features)

        logging.info("***** Running mm evaluation *****")
        logging.info("  Num examples = %d", len(mm_eval_examples))

        mm_eval_sampler = SequentialSampler(mm_eval_data)
        mm_eval_dataloader = DataLoader(mm_eval_data,
                                        sampler=mm_eval_sampler,
                                        batch_size=args.batch_size)
    else:
        mm_eval_labels = None
        mm_eval_dataloader = None

    if not args.do_eval:  # need the teacher model for training
        teacher_model = BertForSequenceClassification.from_pretrained(
            args.teacher_model, config=config)
        teacher_model.to(device)
        if n_gpu > 1:
            teacher_model = torch.nn.DataParallel(teacher_model)
    else:
        teacher_model = None

    # logging.info("Rename the config and checkpoint to restore if necessary.")
    # if not os.path.isfile(os.path.join(args.student_model, 'config.json')):
    #     os.system('cp -v %s/%s %s/%s' % (args.student_model, 'kd_stage2_config.json', args.student_model, 'config.json'))
    # if not os.path.isfile(os.path.join(args.student_model, 'pytorch_model.bin')):
    #     os.system('cp -v %s/%s %s/%s' % (args.student_model, 'kd_stage2_pytorch_model.bin', args.student_model, 'pytorch_model.bin'))

    if args.split:
        # rename the checkpoint to restore
        split_model_dir = os.path.join(args.output_dir, 'binary_model_init')
        if not os.path.exists(split_model_dir):
            os.mkdir(split_model_dir)
        # copy the json file, avoid over-writing
        source_model_dir = os.path.join(args.student_model, CONFIG_NAME)
        target_model_dir = os.path.join(split_model_dir, CONFIG_NAME)
        os.system('cp -v %s %s' % (source_model_dir, target_model_dir))

        # create the split model ckpt
        source_model_dir = os.path.join(args.student_model, WEIGHTS_NAME)
        target_model_dir = os.path.join(split_model_dir, WEIGHTS_NAME)
        target_model_dir = tws_split(source_model_dir, target_model_dir)
        args.student_model = split_model_dir  # over-write student_model dir
        print(
            "transformed binary model stored at: {}".format(target_model_dir))

    if args.is_binarybert:
        student_model = BertForSequenceClassification_binary.from_pretrained(
            args.student_model, config=student_config)
    else:
        student_model = QuantBertForSequenceClassification.from_pretrained(
            args.student_model, config=student_config)
    student_model.to(device)
    if n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)

    learner = KDLearner(args, device, student_model, teacher_model,
                        num_train_optimization_steps)

    if args.do_eval:
        """ evaluation """
        learner.evaluate(task_name,
                         eval_dataloader,
                         output_mode,
                         eval_labels,
                         num_labels,
                         eval_examples,
                         mm_eval_dataloader=mm_eval_dataloader,
                         mm_eval_labels=mm_eval_labels)
        return 0
    """ perform training """
    if args.kd_type == 'joint_kd':
        learner.build()
        learner.train(train_examples,
                      task_name,
                      output_mode,
                      eval_labels,
                      num_labels,
                      train_dataloader,
                      eval_dataloader,
                      eval_examples,
                      tokenizer,
                      mm_eval_dataloader=mm_eval_dataloader,
                      mm_eval_labels=mm_eval_labels)

    elif args.kd_type == 'logit_kd':
        # only perform the logits kd
        learner.build(lr=args.learning_rate)
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = False
        learner.train(train_examples,
                      task_name,
                      output_mode,
                      eval_labels,
                      num_labels,
                      train_dataloader,
                      eval_dataloader,
                      eval_examples,
                      tokenizer,
                      mm_eval_dataloader=mm_eval_dataloader,
                      mm_eval_labels=mm_eval_labels)

    elif args.kd_type == 'two_stage':
        # stage 1: intermediate layer distillation
        learner.args.distill_logit = False
        learner.args.distill_rep_attn = True
        learner.build(lr=2.5 * args.learning_rate)
        learner.train(train_examples,
                      task_name,
                      output_mode,
                      eval_labels,
                      num_labels,
                      train_dataloader,
                      eval_dataloader,
                      eval_examples,
                      tokenizer,
                      mm_eval_dataloader=mm_eval_dataloader,
                      mm_eval_labels=mm_eval_labels)

        # stage 2: prediction layer distillation
        learner.student_model.load_state_dict(
            torch.load(os.path.join(learner.output_dir, 'pytorch_model.bin')))
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = False

        learner.build(lr=args.learning_rate)  # prepare the optimizer again.
        learner.train(train_examples,
                      task_name,
                      output_mode,
                      eval_labels,
                      num_labels,
                      train_dataloader,
                      eval_dataloader,
                      eval_examples,
                      tokenizer,
                      mm_eval_dataloader=mm_eval_dataloader,
                      mm_eval_labels=mm_eval_labels)

    else:
        assert args.kd_type == 'no_kd'
        # NO kd training, vanilla cross entropy with hard label
        learner.build(lr=args.learning_rate)  # prepare the optimizer again.
        learner.train(train_examples,
                      task_name,
                      output_mode,
                      eval_labels,
                      num_labels,
                      train_dataloader,
                      eval_dataloader,
                      eval_examples,
                      tokenizer,
                      mm_eval_dataloader=mm_eval_dataloader,
                      mm_eval_labels=mm_eval_labels)

    del learner
    return 0
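The --input_bits / --input_quant_method options above control activation quantization inside the student model. Below is a hedged sketch of the basic idea only; the repository's 'uniform' and 'lsq' quantizers use learnable clip values and step sizes, so this is not the actual implementation:

import torch

def fake_quantize_activation(x, num_bits=8, clip_val=2.5):
    # k-bit symmetric uniform fake-quantization with a straight-through estimator.
    qmax = 2 ** (num_bits - 1) - 1
    scale = clip_val / qmax
    x_q = torch.round(torch.clamp(x, -clip_val, clip_val) / scale) * scale
    return x + (x_q - x).detach()  # forward uses x_q; backward gradients flow as identity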
Exemple #16
0
SEP = '[SEP]'
CLS = '[CLS]'
MASK = '[MASK]'

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout,
                    level=logging.INFO,
                    format=log_format,
                    datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler('debug_layer_loss.log')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
logger = logging.getLogger()

pretrained_bert_model = f"/rscratch/bohan/ZQBert/zero-shot-qbert/Berts/mrpc_base_l12/"
#pretrained_bert_model = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_bert_model)
model = BertForMaskedLM.from_pretrained(pretrained_bert_model)

mask_id = tokenizer.convert_tokens_to_ids([MASK])[0]
sep_id = tokenizer.convert_tokens_to_ids([SEP])[0]
cls_id = tokenizer.convert_tokens_to_ids([CLS])[0]

model.eval()
cuda = torch.cuda.is_available()
if cuda:
    model = model.cuda()


def tokenize_batch(batch):
    return [tokenizer.convert_tokens_to_ids(sent) for sent in batch]
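A hedged usage sketch for tokenize_batch, assuming the tokenizer above loaded successfully (the example sentences are arbitrary; the resulting ids depend on the vocabulary):

if __name__ == '__main__':
    demo_batch = [['[CLS]', 'the', 'cat', 'sat', '[SEP]'],
                  ['[CLS]', 'the', MASK, 'sat', '[SEP]']]
    print(tokenize_batch(demo_batch))  # a list of id lists with the same shape as the input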