Code example #1
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pretrained_bert_model",
        default=None,
        type=str,
        required=True,
        help="Folder containing the downloaded pretrained model "
        "(bert-base-cased/uncased)")
    parser.add_argument("--glove_embs",
                        default=None,
                        type=str,
                        required=True,
                        help="Glove word embeddings file")
    parser.add_argument("--glue_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="GLUE data dir")
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="Task (e.g. CoLA, SST-2) whose training set we want to augment")
    parser.add_argument("--N",
                        default=30,
                        type=int,
                        help="How many times is the corpus expanded?")
    parser.add_argument(
        "--M",
        default=15,
        type=int,
        help="Choose from the M most likely words at the corresponding position")
    parser.add_argument("--p",
                        default=0.4,
                        type=float,
                        help="Threshold probability p to replace current word")

    args = parser.parse_args()
    # logger.info(args)

    default_params = {
        "CoLA": {"N": 30},
        "MNLI": {"N": 10},
        "MRPC": {"N": 30},
        "SST-2": {"N": 20},
        "STS-b": {"N": 30},
        "QQP": {"N": 10},
        "QNLI": {"N": 20},
        "RTE": {"N": 30}
    }

    # NOTE: for tasks listed above, this overrides any --N passed on the
    # command line with the per-task default.
    if args.task_name in default_params:
        args.N = default_params[args.task_name]["N"]

    # Prepare data augmentor
    tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_model)
    model = BertForMaskedLM.from_pretrained(args.pretrained_bert_model)
    model.eval()

    emb_norm, vocab, ids_to_tokens = prepare_embedding_retrieval(
        args.glove_embs)

    data_augmentor = DataAugmentor(model, tokenizer, emb_norm, vocab,
                                   ids_to_tokens, args.M, args.N, args.p)

    # Do data augmentation
    processor = AugmentProcessor(data_augmentor, args.glue_dir, args.task_name)
    processor.read_augment_write()
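To make the replacement step concrete, here is a minimal sketch of the masked-LM lookup that a DataAugmentor of this shape performs for one token position. It is a hypothetical standalone example, assuming a transformers-style BertForMaskedLM whose first output is the logits tensor; top_m_replacements, sentence_tokens and position are illustrative names, not from the repo.

import torch

def top_m_replacements(sentence_tokens, position, tokenizer, model, M=15):
    """Return the M most likely tokens for `position` under the masked LM."""
    tokens = list(sentence_tokens)
    tokens[position] = "[MASK]"
    ids = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokens + ["[SEP]"])
    input_ids = torch.tensor([ids])
    with torch.no_grad():
        logits = model(input_ids)[0]       # (1, seq_len, vocab_size)
    word_logits = logits[0, position + 1]  # +1 to skip the [CLS] token
    top_ids = torch.topk(word_logits, M).indices.tolist()
    return tokenizer.convert_ids_to_tokens(top_ids)

The full augmentor presumably also uses the GloVe embeddings loaded via prepare_embedding_retrieval for words that BERT splits into multiple pieces, and applies the --M, --N and --p settings parsed above.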
Code example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default='data',
                        type=str,
                        #required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model_dir",
                        default='models/finetuned_teacher/',
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--tasks",
                        default='RTE,MRPC,STS-B,CoLA,SST-2,QNLI',
                        type=str,
                        #required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default='output',
                        type=str,
                        #required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    parser.add_argument("--max_seq_length",
                        default=None,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")

    parser.add_argument("--do_lower_case",default = True,
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")

    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")

    parser.add_argument("--root_dir", default='./', type=str)
    parser.add_argument("--log_dir", default='', type=str)
    parser.add_argument("--tensorboard_dir", default='', type=str)
    parser.add_argument("--model_save_dir", default='', type=str)

    args = parser.parse_args()
    
    logger.info('The args: {}'.format(args))
    args.data_dir = os.path.join(args.root_dir, args.data_dir)
    args.model_dir = os.path.join(args.root_dir, args.model_dir)
    args.output_dir = os.path.join(args.model_save_dir, args.output_dir)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor
        }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification"
        }

    default_params = {
        "cola": {"max_seq_length": 64},
        "mnli": {"max_seq_length": 128},
        "mrpc": {"max_seq_length": 128},
        "sst-2": {"max_seq_length": 64},
        "sts-b": {"max_seq_length": 128},
        "qqp": {"max_seq_length": 128},
        "qnli": {"max_seq_length": 128},
        "rte": {"max_seq_length": 128}
        }
    
    infer_files = {
        "cola": "CoLA.tsv",
        "mnli": "MNLI-m.tsv",
        "mrpc": "MRPC.tsv",
        "sst-2": "SST-2.tsv",
        "sts-b": "STS-B.tsv",
        "qqp": "QQP.tsv",
        "qnli": "QNLI.tsv",
        "rte": "RTE.tsv",
        "wnli": "WNLI.tsv"
        }

    # Prepare devices
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()


    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tasks = args.tasks.lower()
    for task_name in tasks.split(','):
        data_dir = os.path.join(args.data_dir, task_name)
        model_dir = os.path.join(args.model_dir, task_name)
        # Resolve the sequence length per task, so one task's default does
        # not leak into the next loop iteration.
        max_seq_length = args.max_seq_length
        if max_seq_length is None and task_name in default_params:
            max_seq_length = default_params[task_name]["max_seq_length"]

        processor = processors[task_name]()
        output_mode = output_modes[task_name]
        label_list = processor.get_labels()
        num_labels = len(label_list)
        output_file = os.path.join(args.output_dir, infer_files[task_name])

        tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=args.do_lower_case)

        examples = processor.get_test_examples(data_dir)
        features = convert_examples_to_features(examples, label_list,
                                                max_seq_length, tokenizer,
                                                output_mode)
        data, labels = get_tensor_data(output_mode, features)
        sampler = SequentialSampler(data)
        dataloader = DataLoader(data, sampler=sampler, batch_size=args.batch_size)
        
        model = BertForSequenceClassification.from_pretrained(
            model_dir, num_labels=num_labels, do_quantize=0)
        model.to(device)

        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(examples))
        logger.info("  Batch size = %d", args.batch_size)

        model.eval()
        do_infer(model, task_name, dataloader, device, output_mode,
                 output_file, label_list)
        
        if task_name == "mnli":
            processor = processors["mnli-mm"]()
            examples = processor.get_test_examples(data_dir)
            output_file = os.path.join(args.output_dir,'MNLI-mm.tsv')
            features = convert_examples_to_features(
                examples, label_list, args.max_seq_length, tokenizer, output_mode)
            data, labels = get_tensor_data(output_mode, features)

            logger.info("***** Running mm evaluation *****")
            logger.info("  Num examples = %d", len(examples))

            sampler = SequentialSampler(data)
            dataloader = DataLoader(data, sampler=sampler,
                                         batch_size=args.batch_size)
            do_infer(model, task_name, dataloader,
                             device, output_mode, output_file,label_list)
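do_infer itself is defined elsewhere in the repo; for orientation, a sequential inference loop of roughly this shape (a hypothetical sketch: the real helper's argument order, model output format and .tsv writing will differ) could look like:

import torch

def infer_loop(model, dataloader, device, output_mode):
    """Collect predictions batch by batch, in dataset order."""
    all_logits = []
    for batch in dataloader:
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids = batch[0], batch[1], batch[2]
        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)
        all_logits.append(logits.detach().cpu())
    logits = torch.cat(all_logits)
    if output_mode == "classification":
        return logits.argmax(dim=-1)   # predicted label indices
    return logits.squeeze(-1)          # regression scores

Because the DataLoader uses a SequentialSampler, the concatenated predictions stay aligned with the input rows, which is what allows writing them straight into the GLUE submission files named in infer_files.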
Code example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default='data',
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--model_dir",
                        default='models/tinybert',
                        type=str,
                        help="The model dir.")
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        help="The models directory.")
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        help="The models directory.")
    parser.add_argument("--task_name",
                        default='sst-2',
                        type=str,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default='output',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument('--aug_train',
                        action='store_false',
                        help="Augmented training data is used by default; "
                        "pass this flag to disable it")
    parser.add_argument('--pred_distill',
                        action='store_true',
                        help="Whether to distil with task layer")
    parser.add_argument('--intermediate_distill',
                        action='store_true',
                        help="Whether to distil with intermediate layers")
    parser.add_argument('--save_fp_model',
                        action='store_true',
                        help="Whether to save fp32 model")
    parser.add_argument('--save_quantized_model',
                        action='store_true',
                        help="Whether to save quantized model")

    parser.add_argument("--weight_bits",
                        default=2,
                        type=int,
                        choices=[2, 8],
                        help="Quantization bits for weight.")
    parser.add_argument("--input_bits",
                        default=8,
                        type=int,
                        help="Quantization bits for activation.")
    parser.add_argument("--clip_val",
                        default=2.5,
                        type=float,
                        help="Initial clip value.")

    args = parser.parse_args()
    assert args.pred_distill or args.intermediate_distill, \
        "At least one of 'pred_distill' and 'intermediate_distill' must be True"
    summaryWriter = SummaryWriter(args.output_dir)
    logger.info('The args: {}'.format(args))
    task_name = args.task_name.lower()
    data_dir = os.path.join(args.data_dir, task_name)
    output_dir = os.path.join(args.output_dir, task_name)
    # Cached features, if any, live under a preprocessed/ subdirectory.
    processed_data_dir = os.path.join(args.data_dir, 'preprocessed', task_name)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.student_model is None:
        args.student_model = os.path.join(args.model_dir, task_name)
    if args.teacher_model is None:
        args.teacher_model = os.path.join(args.model_dir, task_name)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification"
    }

    default_params = {
        "cola": {"max_seq_length": 64, "batch_size": 16, "eval_step": 50},
        "mnli": {"max_seq_length": 128, "batch_size": 32, "eval_step": 1000},
        "mrpc": {"max_seq_length": 128, "batch_size": 32, "eval_step": 200},
        "sst-2": {"max_seq_length": 64, "batch_size": 32, "eval_step": 200},
        "sts-b": {"max_seq_length": 128, "batch_size": 32, "eval_step": 50},
        "qqp": {"max_seq_length": 128, "batch_size": 32, "eval_step": 1000},
        "qnli": {"max_seq_length": 128, "batch_size": 32, "eval_step": 1000},
        "rte": {"max_seq_length": 128, "batch_size": 32, "eval_step": 100}
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if task_name in default_params:
        args.batch_size = default_params[task_name]["batch_size"]
        if n_gpu > 0:
            args.batch_size = int(args.batch_size * n_gpu)
        args.max_seq_length = default_params[task_name]["max_seq_length"]
        args.eval_step = default_params[task_name]["eval_step"]

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=True)

    if args.aug_train:
        try:
            train_file = os.path.join(processed_data_dir, 'aug_data')
            with open(train_file, 'rb') as f:
                train_features = pickle.load(f)
        except FileNotFoundError:
            train_examples = processor.get_aug_examples(data_dir)
            train_features = convert_examples_to_features(
                train_examples, label_list, args.max_seq_length, tokenizer,
                output_mode)
    else:
        try:
            train_file = os.path.join(processed_data_dir, 'train_data')
            with open(train_file, 'rb') as f:
                train_features = pickle.load(f)
        except FileNotFoundError:
            train_examples = processor.get_train_examples(data_dir)
            train_features = convert_examples_to_features(
                train_examples, label_list, args.max_seq_length, tokenizer,
                output_mode)

    num_train_optimization_steps = int(
        len(train_features) / args.batch_size * args.num_train_epochs)
    train_data, _ = get_tensor_data(output_mode, train_features)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size)

    try:
        dev_file = os.path.join(processed_data_dir, 'dev_data')
        with open(dev_file, 'rb') as f:
            eval_features = pickle.load(f)
    except FileNotFoundError:
        eval_examples = processor.get_dev_examples(data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, output_mode)

    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.batch_size)
    if task_name == "mnli":
        processor = processors["mnli-mm"]()
        try:
            dev_mm_file = os.path.join(processed_data_dir, 'dev-mm_data')
            with open(dev_mm_file, 'rb') as f:
                mm_eval_features = pickle.load(f)
        except FileNotFoundError:
            mm_eval_examples = processor.get_dev_examples(data_dir)
            mm_eval_features = convert_examples_to_features(
                mm_eval_examples, label_list, args.max_seq_length, tokenizer,
                output_mode)

        mm_eval_data, mm_eval_labels = get_tensor_data(output_mode,
                                                       mm_eval_features)
        logger.info("  Num examples = %d", len(mm_eval_features))

        mm_eval_sampler = SequentialSampler(mm_eval_data)
        mm_eval_dataloader = DataLoader(mm_eval_data,
                                        sampler=mm_eval_sampler,
                                        batch_size=args.batch_size)

    teacher_model = BertForSequenceClassification.from_pretrained(
        args.teacher_model)
    teacher_model.to(device)
    teacher_model.eval()
    if n_gpu > 1:
        teacher_model = torch.nn.DataParallel(teacher_model)

    result = do_eval(teacher_model, task_name, eval_dataloader, device,
                     output_mode, eval_labels, num_labels)
    fp32_performance = ''
    if task_name in acc_tasks:
        if task_name in ['sst-2', 'mnli', 'qnli', 'rte']:
            fp32_performance = f"acc:{result['acc']}"
        elif task_name in ['mrpc', 'qqp']:
            fp32_performance = f"f1/acc:{result['f1']}/{result['acc']}"
    if task_name in corr_tasks:
        fp32_performance = f"pearson/spearmanr:{result['pearson']}/{result['spearmanr']}"
    if task_name in mcc_tasks:
        fp32_performance = f"mcc:{result['mcc']}"

    if task_name == "mnli":
        result = do_eval(teacher_model, 'mnli-mm', mm_eval_dataloader, device,
                         output_mode, mm_eval_labels, num_labels)
        fp32_performance += f"  mm-acc:{result['acc']}"
    fp32_performance = task_name + ' fp32   ' + fp32_performance
    student_config = BertConfig.from_pretrained(args.teacher_model,
                                                quantize_act=True,
                                                weight_bits=args.weight_bits,
                                                input_bits=args.input_bits,
                                                clip_val=args.clip_val)
    student_model = QuantBertForSequenceClassification.from_pretrained(
        args.student_model, config=student_config, num_labels=num_labels)
    student_model.to(device)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_features))
    logger.info("  Batch size = %d", args.batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)
    if n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)

    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
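    # Biases and LayerNorm parameters are conventionally exempt from weight
    # decay; the two parameter groups below encode that split.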
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    schedule = 'warmup_linear'
    optimizer = BertAdam(optimizer_grouped_parameters,
                         schedule=schedule,
                         lr=args.learning_rate,
                         warmup=0.1,
                         t_total=num_train_optimization_steps)
    loss_mse = MSELoss()
    global_step = 0
    best_dev_acc = 0.0
    previous_best = None

    tr_loss = 0.
    tr_att_loss = 0.
    tr_rep_loss = 0.
    tr_cls_loss = 0.
    for epoch_ in range(int(args.num_train_epochs)):
        nb_tr_examples, nb_tr_steps = 0, 0

        for step, batch in enumerate(train_dataloader):
            student_model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
            att_loss = 0.
            rep_loss = 0.
            cls_loss = 0.
            loss = 0.

            student_logits, student_atts, student_reps = student_model(
                input_ids, segment_ids, input_mask)

            with torch.no_grad():
                teacher_logits, teacher_atts, teacher_reps = teacher_model(
                    input_ids, segment_ids, input_mask)

            if args.pred_distill:
                if output_mode == "classification":
                    cls_loss = soft_cross_entropy(student_logits,
                                                  teacher_logits)
                elif output_mode == "regression":
                    cls_loss = loss_mse(student_logits, teacher_logits)

                loss = cls_loss
                tr_cls_loss += cls_loss.item()

            if args.intermediate_distill:
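                # Masked positions hold large negative attention scores; zero
                # them in both models so they do not dominate the MSE below.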
                for student_att, teacher_att in zip(student_atts,
                                                    teacher_atts):
                    student_att = torch.where(
                        student_att <= -1e2,
                        torch.zeros_like(student_att).to(device), student_att)
                    teacher_att = torch.where(
                        teacher_att <= -1e2,
                        torch.zeros_like(teacher_att).to(device), teacher_att)
                    tmp_loss = loss_mse(student_att, teacher_att)
                    att_loss += tmp_loss

                for student_rep, teacher_rep in zip(student_reps,
                                                    teacher_reps):
                    tmp_loss = loss_mse(student_rep, teacher_rep)
                    rep_loss += tmp_loss

                loss += rep_loss + att_loss
                tr_att_loss += att_loss.item()
                tr_rep_loss += rep_loss.item()

            if n_gpu > 1:
                loss = loss.mean()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            tr_loss += loss.item()
            nb_tr_examples += label_ids.size(0)
            nb_tr_steps += 1
            if global_step % args.eval_step == 0 or global_step == num_train_optimization_steps - 1:
                logger.info("***** Running evaluation *****")
                logger.info("  {} step of {} steps".format(
                    global_step, num_train_optimization_steps))
                if previous_best is not None:
                    logger.info(
                        f"{fp32_performance}\nPrevious best = {previous_best}")

                student_model.eval()

                loss = tr_loss / (step + 1)
                cls_loss = tr_cls_loss / (step + 1)
                att_loss = tr_att_loss / (step + 1)
                rep_loss = tr_rep_loss / (step + 1)

                result = do_eval(student_model, task_name, eval_dataloader,
                                 device, output_mode, eval_labels, num_labels)
                result['global_step'] = global_step
                result['cls_loss'] = cls_loss
                result['att_loss'] = att_loss
                result['rep_loss'] = rep_loss
                result['loss'] = loss
                summaryWriter.add_scalar('total_loss', loss, global_step)
                summaryWriter.add_scalars(
                    'distill_loss', {
                        'att_loss': att_loss,
                        'rep_loss': rep_loss,
                        'cls_loss': cls_loss
                    }, global_step)

                if task_name == 'cola':
                    summaryWriter.add_scalar('mcc', result['mcc'], global_step)
                elif task_name in [
                        'sst-2', 'mnli', 'mnli-mm', 'qnli', 'rte', 'wnli'
                ]:
                    summaryWriter.add_scalar('acc', result['acc'], global_step)
                elif task_name in ['mrpc', 'qqp']:
                    summaryWriter.add_scalars(
                        'performance', {
                            'acc': result['acc'],
                            'f1': result['f1'],
                            'acc_and_f1': result['acc_and_f1']
                        }, global_step)
                else:
                    summaryWriter.add_scalar('corr', result['corr'],
                                             global_step)

                save_model = False

                if task_name in acc_tasks and result['acc'] > best_dev_acc:
                    if task_name in ['sst-2', 'mnli', 'qnli', 'rte']:
                        previous_best = f"acc:{result['acc']}"
                    elif task_name in ['mrpc', 'qqp']:
                        previous_best = f"f1/acc:{result['f1']}/{result['acc']}"
                    best_dev_acc = result['acc']
                    save_model = True

                if task_name in corr_tasks and result['corr'] > best_dev_acc:
                    previous_best = f"pearson/spearmanr:{result['pearson']}/{result['spearmanr']}"
                    best_dev_acc = result['corr']
                    save_model = True

                if task_name in mcc_tasks and result['mcc'] > best_dev_acc:
                    previous_best = f"mcc:{result['mcc']}"
                    best_dev_acc = result['mcc']
                    save_model = True

                if save_model:
                    # Test mnli-mm
                    if task_name == "mnli":
                        result = do_eval(student_model, 'mnli-mm',
                                         mm_eval_dataloader, device,
                                         output_mode, mm_eval_labels,
                                         num_labels)
                        previous_best += f"mm-acc:{result['acc']}"
                    logger.info(fp32_performance)
                    logger.info(previous_best)
                    if args.save_fp_model:
                        logger.info(
                            "******************** Save full precision model ********************"
                        )
                        model_to_save = student_model.module if hasattr(
                            student_model, 'module') else student_model
                        output_model_file = os.path.join(
                            output_dir, WEIGHTS_NAME)
                        output_config_file = os.path.join(
                            output_dir, CONFIG_NAME)

                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(output_dir)
                    if args.save_quantized_model:
                        logger.info(
                            "******************** Save quantized model ********************"
                        )
                        output_quant_dir = os.path.join(output_dir, 'quant')
                        if not os.path.exists(output_quant_dir):
                            os.makedirs(output_quant_dir)
                        model_to_save = student_model.module if hasattr(
                            student_model, 'module') else student_model
                        quant_model = copy.deepcopy(model_to_save)
                        for name, module in quant_model.named_modules():
                            if hasattr(module, 'weight_quantizer'):
                                module.weight.data = module.weight_quantizer.apply(
                                    module.weight, module.weight_clip_val,
                                    module.weight_bits, True)

                        output_model_file = os.path.join(
                            output_quant_dir, WEIGHTS_NAME)
                        output_config_file = os.path.join(
                            output_quant_dir, CONFIG_NAME)

                        torch.save(quant_model.state_dict(), output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(output_quant_dir)
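soft_cross_entropy comes from the repo's utilities; a common definition for this kind of logit distillation (a sketch only, and the repo's version may differ, e.g. by a temperature term) is:

import torch.nn.functional as F

def soft_cross_entropy(predicts, targets):
    """Cross-entropy between student logits and the teacher's soft targets."""
    student_log_probs = F.log_softmax(predicts, dim=-1)
    teacher_probs = F.softmax(targets, dim=-1)
    return (-teacher_probs * student_log_probs).mean()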
Code example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default='data/',
                        type=str,
                        help="The data directory.")
    parser.add_argument("--model_dir",
                        default='models/',
                        type=str,
                        help="The models directory.")
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        help="The models directory.")
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        help="The models directory.")
    parser.add_argument(
        "--output_dir",
        default='output',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument('--version_2_with_negative',
                        action='store_true',
                        help="Use SQuAD v2.0 if set, otherwise SQuAD v1.1")

    # default
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", default=0, type=int)
    parser.add_argument(
        '--null_score_diff_threshold',
        type=float,
        default=0.0,
        help=
        "If null_score - best_non_null is greater than the threshold predict null."
    )
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    # NOTE: declared without a type or action, so any value passed on the
    # command line (even "False") is a non-empty string and thus truthy.
    parser.add_argument('--do_lower_case',
                        default=True,
                        help="do lower case")

    parser.add_argument("--per_gpu_batch_size",
                        default=16,
                        type=int,
                        help="Per GPU batch size for training.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--eval_step',
                        type=int,
                        default=200,
                        help="Evaluate every X training steps")

    parser.add_argument('--pred_distill',
                        action='store_true',
                        help="Whether to distil with task layer")
    parser.add_argument('--intermediate_distill',
                        action='store_true',
                        help="Whether to distil with intermediate layers")
    parser.add_argument('--save_fp_model',
                        action='store_true',
                        help="Whether to save fp32 model")
    parser.add_argument('--save_quantized_model',
                        action='store_true',
                        help="Whether to save quantized model")

    parser.add_argument("--weight_bits",
                        default=2,
                        type=int,
                        choices=[2, 8],
                        help="Quantization bits for weight.")
    parser.add_argument("--input_bits",
                        default=8,
                        type=int,
                        help="Quantization bits for activation.")
    parser.add_argument("--clip_val",
                        default=2.5,
                        type=float,
                        help="Initial clip value.")

    args = parser.parse_args()
    summaryWriter = SummaryWriter(args.output_dir)

    if args.teacher_model is None:
        args.teacher_model = args.model_dir
    if args.student_model is None:
        args.student_model = args.model_dir

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # max(1, n_gpu) guards against a zero batch size on CPU-only machines.
    args.batch_size = max(1, args.n_gpu) * args.per_gpu_batch_size

    logger.info(f'The args: {args}')
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.teacher_model,
                                              do_lower_case=True)

    # preparing training data: load cached features if present,
    # otherwise build them from the raw SQuAD JSON below
    input_file = 'train-v2.0' if args.version_2_with_negative else 'train-v1.1'
    input_file = os.path.join(args.data_dir, input_file)
    if os.path.exists(input_file):
        with open(input_file, 'rb') as f:
            train_features = pickle.load(f)
    else:
        input_file = 'train-v2.0.json' if args.version_2_with_negative else 'train-v1.1.json'
        input_file = os.path.join(args.data_dir, input_file)
        _, train_examples = read_squad_examples(
            input_file=input_file,
            is_training=True,
            version_2_with_negative=args.version_2_with_negative)
        train_features = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=True)

    num_train_optimization_steps = int(
        len(train_features) / args.batch_size * args.num_train_epochs)
    logger.info("***** Running training *****")
    logger.info("  Num split examples = %d", len(train_features))
    logger.info("  Batch size = %d", args.batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_start_positions = torch.tensor(
        [f.start_position for f in train_features], dtype=torch.long)
    all_end_positions = torch.tensor([f.end_position for f in train_features],
                                     dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_start_positions, all_end_positions)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size)

    input_file = 'dev-v2.0.json' if args.version_2_with_negative else 'dev-v1.1.json'
    args.dev_file = os.path.join(args.data_dir, input_file)
    dev_dataset, eval_examples = read_squad_examples(
        input_file=args.dev_file,
        is_training=False,
        version_2_with_negative=args.version_2_with_negative)
    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)
    logger.info("***** Running predictions *****")
    logger.info("  Num orig examples = %d", len(eval_examples))
    logger.info("  Num split examples = %d", len(eval_features))
    logger.info("  Batch size = %d", args.batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                              all_example_index)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.batch_size)

    teacher_model = BertForQuestionAnswering.from_pretrained(
        args.teacher_model)
    teacher_model.to(args.device)
    teacher_model.eval()
    if args.n_gpu > 1:
        teacher_model = torch.nn.DataParallel(teacher_model)
    result = do_eval(args, teacher_model, eval_dataloader, eval_features,
                     eval_examples, args.device, dev_dataset)
    em, f1 = result['exact_match'], result['f1']
    logger.info(f"Full precision teacher exact_match={em},f1={f1}")

    student_config = BertConfig.from_pretrained(args.student_model,
                                                quantize_act=True,
                                                weight_bits=args.weight_bits,
                                                input_bits=args.input_bits,
                                                clip_val=args.clip_val)
    student_model = QuantBertForQuestionAnswering.from_pretrained(
        args.student_model, config=student_config)
    student_model.to(args.device)

    if args.n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)
    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    schedule = 'warmup_linear'
    optimizer = BertAdam(optimizer_grouped_parameters,
                         schedule=schedule,
                         lr=args.learning_rate,
                         warmup=0.1,
                         t_total=num_train_optimization_steps)

    loss_mse = MSELoss()
    # Train and evaluate
    global_step = 0
    best_dev_f1 = 0.0
    flag_loss = float('inf')
    previous_best = None
    tr_loss = 0.
    tr_att_loss = 0.
    tr_rep_loss = 0.
    tr_cls_loss = 0.
    for epoch_ in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            student_model.train()
            batch = tuple(t.to(args.device) for t in batch)

            input_ids, input_mask, segment_ids, start_positions, end_positions = batch
            att_loss = 0.
            rep_loss = 0.
            cls_loss = 0.
            loss = 0

            student_logits, student_atts, student_reps = student_model(
                input_ids, segment_ids, input_mask)
            with torch.no_grad():
                teacher_logits, teacher_atts, teacher_reps = teacher_model(
                    input_ids, segment_ids, input_mask)

            if args.pred_distill:
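                # For QA the model emits separate start- and end-position
                # logits; distill each against the teacher's counterpart.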
                soft_start_ce_loss = soft_cross_entropy(
                    student_logits[0], teacher_logits[0])
                soft_end_ce_loss = soft_cross_entropy(student_logits[1],
                                                      teacher_logits[1])
                cls_loss = soft_start_ce_loss + soft_end_ce_loss
                loss += cls_loss
                tr_cls_loss += cls_loss.item()

            if args.intermediate_distill:
                for student_att, teacher_att in zip(student_atts,
                                                    teacher_atts):
                    student_att = torch.where(
                        student_att <= -1e2,
                        torch.zeros_like(student_att).to(args.device),
                        student_att)
                    teacher_att = torch.where(
                        teacher_att <= -1e2,
                        torch.zeros_like(teacher_att).to(args.device),
                        teacher_att)
                    tmp_loss = loss_mse(student_att, teacher_att)
                    att_loss += tmp_loss

                for student_rep, teacher_rep in zip(student_reps,
                                                    teacher_reps):
                    tmp_loss = loss_mse(student_rep, teacher_rep)
                    rep_loss += tmp_loss

                loss += rep_loss + att_loss
                tr_att_loss += att_loss.item()
                tr_rep_loss += rep_loss.item()

            if args.n_gpu > 1:
                loss = loss.mean()

            loss.backward()
            tr_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            save_model = False
            if global_step % args.eval_step == 0 or global_step == num_train_optimization_steps - 1:
                logger.info("***** Running evaluation *****")
                logger.info(f"  Epoch = {epoch_} iter {global_step} step")
                if previous_best is not None:
                    logger.info(f"Previous best = {previous_best}")

                student_model.eval()
                result = do_eval(args, student_model, eval_dataloader,
                                 eval_features, eval_examples, args.device,
                                 dev_dataset)
                em, f1 = result['exact_match'], result['f1']
                logger.info(f'{em}/{f1}')
                if f1 > best_dev_f1:
                    previous_best = f"exact_match={em},f1={f1}"
                    best_dev_f1 = f1
                    save_model = True

                summaryWriter.add_scalars('performance', {
                    'exact_match': em,
                    'f1': f1
                }, global_step)
                loss = tr_loss / global_step
                cls_loss = tr_cls_loss / global_step
                att_loss = tr_att_loss / global_step
                rep_loss = tr_rep_loss / global_step

                summaryWriter.add_scalar('total_loss', loss, global_step)
                summaryWriter.add_scalars(
                    'distill_loss', {
                        'att_loss': att_loss,
                        'rep_loss': rep_loss,
                        'cls_loss': cls_loss
                    }, global_step)

            # save the quantized model
            if save_model:
                logger.info(previous_best)
                if args.save_fp_model:
                    logger.info(
                        "******************** Save full precision model ********************"
                    )
                    model_to_save = student_model.module if hasattr(
                        student_model, 'module') else student_model
                    output_model_file = os.path.join(args.output_dir,
                                                     WEIGHTS_NAME)
                    output_config_file = os.path.join(args.output_dir,
                                                      CONFIG_NAME)

                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(args.output_dir)
                if args.save_quantized_model:
                    logger.info(
                        "******************** Save quantized model ********************"
                    )
                    output_quant_dir = os.path.join(args.output_dir, 'quant')
                    if not os.path.exists(output_quant_dir):
                        os.makedirs(output_quant_dir)
                    model_to_save = student_model.module if hasattr(
                        student_model, 'module') else student_model
                    quant_model = copy.deepcopy(model_to_save)
                    for name, module in quant_model.named_modules():
                        if hasattr(module, 'weight_quantizer'):
                            module.weight.data = module.weight_quantizer.apply(
                                module.weight, module.weight_clip_val,
                                module.weight_bits, True)

                    output_model_file = os.path.join(output_quant_dir,
                                                     WEIGHTS_NAME)
                    output_config_file = os.path.join(output_quant_dir,
                                                      CONFIG_NAME)

                    torch.save(quant_model.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(output_quant_dir)
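The weight_quantizer modules touched above are custom autograd functions from this repo; conceptually, the baking step replaces each weight tensor with its clip-and-round fake-quantized version. A generic sketch of that idea (symmetric uniform quantization; the actual 2-bit ternary scheme differs in detail) is:

import torch

def fake_quantize(w, clip_val=2.5, bits=8):
    """Clip to [-clip_val, clip_val], round onto a uniform grid, rescale."""
    w = torch.clamp(w, -clip_val, clip_val)
    levels = 2 ** (bits - 1) - 1
    scale = clip_val / levels
    return torch.round(w / scale) * scale

After this substitution, quant_model.state_dict() already holds quantized weights, so the checkpoint written to output_quant_dir can be evaluated without re-running the quantizers.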