Example #1
class Runner():
    ''' Handler for the complete pre-training process of upstream models '''
    def __init__(self, args, config, dataloader, ckpdir):

        self.device = torch.device('cuda') if (
            args.gpu and torch.cuda.is_available()) else torch.device('cpu')
        if torch.cuda.is_available(): print('[Runner] - CUDA is available!')
        self.model_kept = []
        self.global_step = 1
        self.log = SummaryWriter(ckpdir)

        self.args = args
        self.config = config
        self.dataloader = dataloader
        self.ckpdir = ckpdir

        # optimizer
        self.learning_rate = float(config['optimizer']['learning_rate'])
        self.warmup_proportion = config['optimizer']['warmup_proportion']
        self.gradient_accumulation_steps = config['optimizer'][
            'gradient_accumulation_steps']
        self.gradient_clipping = config['optimizer']['gradient_clipping']

        # Training details
        self.apex = config['runner']['apex']
        self.total_steps = config['runner']['total_steps']
        self.log_step = config['runner']['log_step']
        self.save_step = config['runner']['save_step']
        self.duo_feature = config['runner']['duo_feature']
        self.max_keep = config['runner']['max_keep']

        # model
        self.transformer_config = config['transformer']
        self.input_dim = self.transformer_config['input_dim']
        # with duo features the target dim is 1025; otherwise None, so the
        # model falls back to using the input dim as the output dim
        self.output_dim = 1025 if self.duo_feature else None

    def set_model(self):
        print('[Runner] - Initializing Transformer model...')

        # build the Transformer model with speech prediction head
        model_config = TransformerConfig(self.config)
        self.dr = model_config.downsample_rate
        self.hidden_size = model_config.hidden_size

        self.model = TransformerForMaskedAcousticModel(
            model_config, self.input_dim, self.output_dim).to(self.device)
        self.model.train()

        if self.args.multi_gpu:
            self.model = torch.nn.DataParallel(self.model)
            print('[Runner] - Multi-GPU training Enabled: ' +
                  str(torch.cuda.device_count()))
        print('[Runner] - Number of parameters: ' + str(
            sum(p.numel()
                for p in self.model.parameters() if p.requires_grad)))

        # Setup optimizer
        param_optimizer = list(self.model.named_parameters())

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
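        # Exclude biases and LayerNorm parameters from weight decay (the
        # standard BERT optimizer grouping).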
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.01,
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]

        if self.apex:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=self.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if self.config['optimizer']['loss_scale'] == 0:
                self.optimizer = FP16_Optimizer(optimizer,
                                                dynamic_loss_scale=True)
            else:
                self.optimizer = FP16_Optimizer(
                    optimizer,
                    static_loss_scale=self.config['optimizer']['loss_scale'])
            self.warmup_linear = WarmupLinearSchedule(
                warmup=self.warmup_proportion, t_total=self.total_steps)
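            # WarmupLinearSchedule ramps the LR factor linearly from 0 to 1
            # over the warmup fraction of steps, then decays it linearly back
            # to 0 at t_total.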
        else:
            self.optimizer = BertAdam(optimizer_grouped_parameters,
                                      lr=self.learning_rate,
                                      warmup=self.warmup_proportion,
                                      t_total=self.total_steps)

    def save_model(self, name='states', to_path=None):
        model = self.model.module if self.args.multi_gpu else self.model
        all_states = {
            'SpecHead': model.SpecHead.state_dict(),
            'Transformer': model.Transformer.state_dict(),
            'Optimizer': self.optimizer.state_dict(),
            'Global_step': self.global_step,
            'Settings': {
                'Config': self.config,
                'Paras': self.args,
            },
        }

        if to_path is None:
            new_model_path = '{}/{}-{}.ckpt'.format(self.ckpdir, name,
                                                    self.global_step)
        else:
            new_model_path = to_path

        torch.save(all_states, new_model_path)
        self.model_kept.append(new_model_path)

        if len(self.model_kept) >= self.max_keep:
            os.remove(self.model_kept[0])
            self.model_kept.pop(0)

    def up_sample_frames(self, spec, return_first=False):
        if len(spec.shape) != 3:
            spec = spec.unsqueeze(0)
            assert (len(spec.shape) == 3
                    ), 'Input should have acoustic feature of shape BxTxD'
        # spec shape: [batch_size, sequence_length // downsample_rate, output_dim * downsample_rate]
        spec_flatten = spec.view(spec.shape[0], spec.shape[1] * self.dr,
                                 spec.shape[2] // self.dr)
        if return_first: return spec_flatten[0]
        return spec_flatten  # spec_flatten shape: [batch_size, sequence_length * downsample_rate, output_dim // downsample_rate]

    def down_sample_frames(self, spec):
        left_over = spec.shape[1] % self.dr
        if left_over != 0: spec = spec[:, :-left_over, :]
        spec_stacked = spec.view(spec.shape[0], spec.shape[1] // self.dr,
                                 spec.shape[2] * self.dr)
        return spec_stacked
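
    # Reshape round trip (with downsample_rate dr): down_sample_frames stacks
    # [B, T, D] into [B, T // dr, D * dr] (dropping leftover frames that do
    # not fill a stack), and up_sample_frames flattens it back to
    # [B, (T // dr) * dr, D].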

    def process_data(self, spec):
        """Process training data for the masked acoustic model"""
        with torch.no_grad():

            assert (
                len(spec) == 5
            ), 'dataloader should return (spec_masked, pos_enc, mask_label, attn_mask, spec_stacked)'
            # Unpack and hack bucket: bucketing should cause acoustic features to have shape 1xBxTxD
            spec_masked = spec[0].squeeze(0)
            pos_enc = spec[1].squeeze(0)
            mask_label = spec[2].squeeze(0)
            attn_mask = spec[3].squeeze(0)
            spec_stacked = spec[4].squeeze(0)

            spec_masked = spec_masked.to(device=self.device)
            if pos_enc.dim() == 3:
                # pos_enc: (batch_size, seq_len, hidden_size)
                # GPU memory needed: batch_size * seq_len * hidden_size
                pos_enc = torch.FloatTensor(pos_enc).to(device=self.device)
            elif pos_enc.dim() == 2:
                # pos_enc: (seq_len, hidden_size)
                # only seq_len * hidden_size of GPU memory is needed even
                # after expand, since expand returns a view rather than a copy
                pos_enc = torch.FloatTensor(pos_enc).to(
                    device=self.device).expand(spec_masked.size(0),
                                               *pos_enc.size())
            mask_label = torch.ByteTensor(mask_label).to(device=self.device)
            attn_mask = torch.FloatTensor(attn_mask).to(device=self.device)
            spec_stacked = spec_stacked.to(device=self.device)

        return spec_masked, pos_enc, mask_label, attn_mask, spec_stacked  # (x, pos_enc, mask_label, attention_mask, y)

    def train(self):
        ''' Self-Supervised Pre-Training of Transformer Model'''

        pbar = tqdm(total=self.total_steps)
        while self.global_step <= self.total_steps:

            progress = tqdm(self.dataloader, desc="Iteration")

            step = 0
            loss_val = 0
            for batch_is_valid, *batch in progress:
                try:
                    if self.global_step > self.total_steps: break
                    if not batch_is_valid: continue
                    step += 1

                    spec_masked, pos_enc, mask_label, attn_mask, spec_stacked = self.process_data(
                        batch)
                    loss, pred_spec = self.model(spec_masked, pos_enc,
                                                 mask_label, attn_mask,
                                                 spec_stacked)

                    # Accumulate Loss
                    if self.gradient_accumulation_steps > 1:
                        loss = loss / self.gradient_accumulation_steps
                    if self.apex and self.args.multi_gpu:
                        raise NotImplementedError
                    elif self.apex:
                        self.optimizer.backward(loss)
                    elif self.args.multi_gpu:
                        loss = loss.sum()
                        loss.backward()
                    else:
                        loss.backward()
                    loss_val += loss.item()

                    # Update
                    if (step + 1) % self.gradient_accumulation_steps == 0:
                        if self.apex:
                            # modify learning rate with the special warmup BERT uses;
                            # if config.apex is False, BertAdam is used and handles this automatically
                            lr_this_step = self.learning_rate * self.warmup_linear.get_lr(
                                self.global_step, self.warmup_proportion)
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = lr_this_step

                        # Step
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.gradient_clipping)
                        if math.isnan(grad_norm):
                            print(
                                '[Runner] - Error : grad norm is NaN @ step ' +
                                str(self.global_step))
                        else:
                            self.optimizer.step()
                        self.optimizer.zero_grad()

                        if self.global_step % self.log_step == 0:
                            # Log
                            self.log.add_scalar('lr',
                                                self.optimizer.get_lr()[0],
                                                self.global_step)
                            self.log.add_scalar('loss', loss_val,
                                                self.global_step)
                            self.log.add_scalar('gradient norm', grad_norm,
                                                self.global_step)
                            progress.set_description("Loss %.4f" % (loss_val))

                        if self.global_step % self.save_step == 0:
                            self.save_model('states')
                            mask_spec = self.up_sample_frames(
                                spec_masked[0], return_first=True)
                            pred_spec = self.up_sample_frames(
                                pred_spec[0], return_first=True)
                            true_spec = self.up_sample_frames(
                                spec_stacked[0], return_first=True)
                            mask_spec = plot_spectrogram_to_numpy(
                                mask_spec.data.cpu().numpy())
                            pred_spec = plot_spectrogram_to_numpy(
                                pred_spec.data.cpu().numpy())
                            true_spec = plot_spectrogram_to_numpy(
                                true_spec.data.cpu().numpy())
                            self.log.add_image('mask_spec', mask_spec,
                                               self.global_step)
                            self.log.add_image('pred_spec', pred_spec,
                                               self.global_step)
                            self.log.add_image('true_spec', true_spec,
                                               self.global_step)

                        loss_val = 0
                        pbar.update(1)
                        self.global_step += 1

                except RuntimeError as e:
                    if 'CUDA out of memory' in str(e):
                        print('CUDA out of memory at step: ', self.global_step)
                        torch.cuda.empty_cache()
                        self.optimizer.zero_grad()
                    else:
                        raise

        pbar.close()
        self.log.close()
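
# A minimal usage sketch for the Runner above. `get_args`, `load_config`, and
# `get_pretrain_dataloader` are hypothetical helpers standing in for the
# project's argument parsing, config loading, and data pipeline; only the
# Runner call order comes from the class itself:
#
#   args = get_args()
#   config = load_config(args.config)
#   dataloader = get_pretrain_dataloader(args, config)
#   runner = Runner(args, config, dataloader, ckpdir=args.ckpdir)
#   runner.set_model()
#   runner.train()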
Example #2
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument('--pregenerated_data', type=Path, required=True)
    parser.add_argument('--teacher_model',
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument('--student_model',
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument('--output_dir', default=None, type=str, required=True)

    # Other parameters
    parser.add_argument(
        '--max_seq_length',
        default=128,
        type=int,
        help=
        'The maximum total input sequence length after WordPiece tokenization. \n'
        'Sequences longer than this will be truncated, and sequences shorter \n'
        'than this will be padded.',
    )

    parser.add_argument(
        '--reduce_memory',
        action='store_true',
        help=
        'Store training data as on-disk memmaps to massively reduce memory usage',
    )
    parser.add_argument(
        '--do_eval',
        action='store_true',
        help='Whether to run eval on the dev set.',
    )
    parser.add_argument(
        '--do_lower_case',
        action='store_true',
        help='Set this flag if you are using an uncased model.',
    )
    parser.add_argument(
        '--train_batch_size',
        default=32,
        type=int,
        help='Total batch size for training.',
    )
    parser.add_argument(
        '--eval_batch_size',
        default=8,
        type=int,
        help='Total batch size for eval.',
    )
    parser.add_argument(
        '--learning_rate',
        default=5e-5,
        type=float,
        help='The initial learning rate for Adam.',
    )
    parser.add_argument(
        '--weight_decay',
        '--wd',
        default=1e-4,
        type=float,
        metavar='W',
        help='weight decay',
    )
    parser.add_argument(
        '--num_train_epochs',
        default=3.0,
        type=float,
        help='Total number of training epochs to perform.',
    )
    parser.add_argument(
        '--warmup_proportion',
        default=0.1,
        type=float,
        help=
        'Proportion of training to perform linear learning rate warmup for. '
        'E.g., 0.1 = 10%% of training.',
    )
    parser.add_argument(
        '--no_cuda',
        action='store_true',
        help='Whether not to use CUDA when available',
    )
    parser.add_argument(
        '--local_rank',
        type=int,
        default=-1,
        help='local_rank for distributed training on gpus',
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='random seed for initialization',
    )
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        'Number of update steps to accumulate before performing a backward/update pass.',
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help='Whether to use 16-bit float precision instead of 32-bit',
    )
    parser.add_argument(
        '--continue_train',
        action='store_true',
        help='Whether to train from checkpoints',
    )

    # Additional arguments
    parser.add_argument('--eval_step', type=int, default=1000)

    # This is used for running on Huawei Cloud.
    parser.add_argument('--data_url', type=str, default='')

    args = parser.parse_args()
    logger.info('args:{}'.format(args))

    samples_per_epoch = []
    for i in range(int(args.num_train_epochs)):
        epoch_file = args.pregenerated_data / 'epoch_{}.json'.format(i)
        metrics_file = args.pregenerated_data / 'epoch_{}_metrics.json'.format(
            i)
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit('No training data was found!')
            print(
                'Warning! There are fewer epochs of pregenerated data ({}) than training epochs ({}).'
                .format(i, args.num_train_epochs))
            print(
                'This script will loop over the available data, but training diversity may be negatively impacted.'
            )
            num_data_epochs = i
            break
    else:
        num_data_epochs = int(args.num_train_epochs)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device('cuda' if torch.cuda.is_available()
                              and not args.no_cuda else 'cpu')
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )

    logger.info(
        'device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}'.
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            'Invalid gradient_accumulation_steps parameter: {}, should be >= 1'
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = (args.train_batch_size //
                             args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            'Output directory ({}) already exists and is not empty.'.format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Build the tokenizer used by PregeneratedDataset and save_vocabulary below
    tokenizer = BertTokenizer.from_pretrained(args.teacher_model,
                                              do_lower_case=args.do_lower_case)

    total_train_examples = 0
    for i in range(int(args.num_train_epochs)):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = int(total_train_examples /
                                       args.train_batch_size /
                                       args.gradient_accumulation_steps)
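    # Illustrative arithmetic (hypothetical numbers): 1,000,000 examples per
    # epoch over 3 epochs with train_batch_size 32 and no accumulation gives
    # 3,000,000 / 32 = 93,750 optimizer steps.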
    if args.local_rank != -1:
        num_train_optimization_steps = (num_train_optimization_steps //
                                        torch.distributed.get_world_size())

    if args.continue_train:
        student_model = TinyBertForPreTraining.from_pretrained(
            args.student_model)
    else:
        student_model = TinyBertForPreTraining.from_scratch(args.student_model)
    teacher_model = BertModel.from_pretrained(args.teacher_model)

    # student_model = TinyBertForPreTraining.from_scratch(args.student_model, fit_size=teacher_model.config.hidden_size)
    student_model.to(device)
    teacher_model.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                'Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.'
            )

        teacher_model = DDP(teacher_model)
    elif n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)
        teacher_model = torch.nn.DataParallel(teacher_model)

    size = 0
    for n, p in student_model.named_parameters():
        logger.info('n: {}'.format(n))
        logger.info('p: {}'.format(p.nelement()))
        size += p.nelement()

    logger.info('Total parameters: {}'.format(size))

    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01,
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]

    loss_mse = MSELoss()
    optimizer = BertAdam(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        warmup=args.warmup_proportion,
        t_total=num_train_optimization_steps,
    )

    global_step = 0
    logging.info('***** Running training *****')
    logging.info('  Num examples = {}'.format(total_train_examples))
    logging.info('  Batch size = %d', args.train_batch_size)
    logging.info('  Num steps = %d', num_train_optimization_steps)

    for epoch in trange(int(args.num_train_epochs), desc='Epoch'):
        epoch_dataset = PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory,
        )
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)
        train_dataloader = DataLoader(
            epoch_dataset,
            sampler=train_sampler,
            batch_size=args.train_batch_size,
        )

        tr_loss = 0.0
        tr_att_loss = 0.0
        tr_rep_loss = 0.0
        student_model.train()
        nb_tr_examples, nb_tr_steps = 0, 0
        with tqdm(total=len(train_dataloader),
                  desc='Epoch {}'.format(epoch)) as pbar:
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc='Iteration', ascii=True)):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                if input_ids.size(0) != args.train_batch_size:
                    continue

                att_loss = 0.0
                rep_loss = 0.0

                student_atts, student_reps = student_model(
                    input_ids, segment_ids, input_mask)
                teacher_reps, teacher_atts, _ = teacher_model(
                    input_ids, segment_ids, input_mask)
                # Detach teacher outputs so no gradients flow into the teacher
                # (roughly a 1.5x speedup).
                teacher_reps = [
                    teacher_rep.detach() for teacher_rep in teacher_reps
                ]
                teacher_atts = [
                    teacher_att.detach() for teacher_att in teacher_atts
                ]

                teacher_layer_num = len(teacher_atts)
                student_layer_num = len(student_atts)
                assert teacher_layer_num % student_layer_num == 0
                layers_per_block = int(teacher_layer_num / student_layer_num)
                new_teacher_atts = [
                    teacher_atts[i * layers_per_block + layers_per_block - 1]
                    for i in range(student_layer_num)
                ]
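                # Each student layer mimics the last teacher layer of its
                # block: student layer i maps to teacher layer
                # (i + 1) * layers_per_block - 1, e.g. a 4-layer student of a
                # 12-layer teacher uses teacher attention layers 2, 5, 8, 11.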

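                # Padded positions carry a large negative additive attention
                # mask (about -1e4), so logits <= -1e2 are zeroed on both
                # sides to keep padding out of the MSE.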
                for student_att, teacher_att in zip(student_atts,
                                                    new_teacher_atts):
                    student_att = torch.where(
                        student_att <= -1e2,
                        torch.zeros_like(student_att).to(device),
                        student_att,
                    )
                    teacher_att = torch.where(
                        teacher_att <= -1e2,
                        torch.zeros_like(teacher_att).to(device),
                        teacher_att,
                    )
                    att_loss += loss_mse(student_att, teacher_att)

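                # Hidden states include the embedding output, so the student
                # has student_layer_num + 1 representations; match them to
                # every layers_per_block-th teacher hidden state.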
                new_teacher_reps = [
                    teacher_reps[i * layers_per_block]
                    for i in range(student_layer_num + 1)
                ]
                new_student_reps = student_reps

                for student_rep, teacher_rep in zip(new_student_reps,
                                                    new_teacher_reps):
                    rep_loss += loss_mse(student_rep, teacher_rep)

                loss = att_loss + rep_loss

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_att_loss += att_loss.item()
                tr_rep_loss += rep_loss.item()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                pbar.update(1)

                mean_loss = (tr_loss * args.gradient_accumulation_steps /
                             nb_tr_steps)
                mean_att_loss = (tr_att_loss *
                                 args.gradient_accumulation_steps /
                                 nb_tr_steps)
                mean_rep_loss = (tr_rep_loss *
                                 args.gradient_accumulation_steps /
                                 nb_tr_steps)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                    if (global_step + 1) % args.eval_step == 0:
                        result = {}
                        result['global_step'] = global_step
                        result['loss'] = mean_loss
                        result['att_loss'] = mean_att_loss
                        result['rep_loss'] = mean_rep_loss
                        output_eval_file = os.path.join(
                            args.output_dir, 'log.txt')
                        with open(output_eval_file, 'a') as writer:
                            logger.info('***** Eval results *****')
                            for key in sorted(result.keys()):
                                logger.info('  %s = %s', key, str(result[key]))
                                writer.write('%s = %s\n' %
                                             (key, str(result[key])))

                        # Save a trained model
                        model_name = 'step_{}_{}'.format(
                            global_step, WEIGHTS_NAME)
                        logging.info(
                            '** ** * Saving fine-tuned model ** ** * ')
                        # Only save the model itself
                        model_to_save = (student_model.module if hasattr(
                            student_model, 'module') else student_model)

                        output_model_file = os.path.join(
                            args.output_dir, model_name)
                        output_config_file = os.path.join(
                            args.output_dir, CONFIG_NAME)

                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)

                        if oncloud:
                            logging.info(
                                mox.file.list_directory(args.output_dir,
                                                        recursive=True))
                            logging.info(
                                mox.file.list_directory('.', recursive=True))
                            mox.file.copy_parallel(args.output_dir,
                                                   args.data_url)
                            mox.file.copy_parallel('.', args.data_url)

            model_name = 'step_{}_{}'.format(global_step, WEIGHTS_NAME)
            logging.info('** ** * Saving fine-tuned model ** ** * ')
            model_to_save = (student_model.module if hasattr(
                student_model, 'module') else student_model)

            output_model_file = os.path.join(args.output_dir, model_name)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            tokenizer.save_vocabulary(args.output_dir)

            if oncloud:
                logging.info(
                    mox.file.list_directory(args.output_dir, recursive=True))
                logging.info(mox.file.list_directory('.', recursive=True))
                mox.file.copy_parallel(args.output_dir, args.data_url)
                mox.file.copy_parallel('.', args.data_url)
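
# A hypothetical launch command for this distillation script (illustrative
# paths and filename; only the argument names come from the parser above):
#
#   python general_distill.py \
#       --pregenerated_data ./pregen_data \
#       --teacher_model bert-base-uncased \
#       --student_model ./student_init \
#       --output_dir ./distill_out \
#       --train_batch_size 32 --num_train_epochs 3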
Example #3
def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--train_file_path",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument("--output_dir", default=None, type=str, required=True)

    # Other parameters
    parser.add_argument(
        "--max_seq_len",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece \n"
        " tokenization. Sequences longer than this will be truncated, \n"
        "and sequences shorter than this will be padded.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-1,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing \n"
        "a backward/update pass.")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--continue_train',
                        action='store_true',
                        help='Whether to train from checkpoints')

    # Additional arguments
    parser.add_argument('--eval_step', type=int, default=1000)

    # This is used for running on Huawei Cloud.
    parser.add_argument('--data_url', type=str, default="")

    args = parser.parse_args()
    logger.info('args:{}'.format(args))

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.teacher_model,
                                              do_lower_case=args.do_lower_case)

    dataset = PregeneratedDataset(args.train_file_path,
                                  tokenizer,
                                  max_seq_len=args.max_seq_len)
    total_train_examples = len(dataset)

    num_train_optimization_steps = int(
        total_train_examples / args.train_batch_size /
        args.gradient_accumulation_steps * args.num_train_epochs)
    if args.local_rank != -1:
        # epochs are already included in the step count above; only divide by
        # the number of distributed workers
        num_train_optimization_steps = (num_train_optimization_steps //
                                        torch.distributed.get_world_size())

    if args.continue_train:
        student_model = TinyBertForPreTraining.from_pretrained(
            args.student_model)
    else:
        student_model = TinyBertForPreTraining.from_scratch(args.student_model)
    teacher_model = BertModel.from_pretrained(args.teacher_model)

    # student_model = TinyBertForPreTraining.from_scratch(args.student_model, fit_size=teacher_model.config.hidden_size)
    student_model.to(device)
    teacher_model.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        teacher_model = DDP(teacher_model)
    elif n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)
        teacher_model = torch.nn.DataParallel(teacher_model)

    size = 0
    for n, p in student_model.named_parameters():
        logger.info('n: {}'.format(n))
        logger.info('p: {}'.format(p.nelement()))
        size += p.nelement()

    logger.info('Total parameters: {}'.format(size))

    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01,
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]

    loss_mse = MSELoss()
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)
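    # BertAdam applies the linear warmup/decay schedule internally (via the
    # warmup and t_total arguments), so the training loop below never adjusts
    # the learning rate by hand.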

    logging.info("***** Running training *****")
    logging.info("  Num examples = {}".format(total_train_examples))
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)

    if args.local_rank == -1:
        train_sampler = RandomSampler(dataset)
    else:
        train_sampler = DistributedSampler(dataset)
    train_dataloader = DataLoader(dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    tr_loss = 0.
    tr_att_loss = 0.
    tr_rep_loss = 0.
    student_model.train()
    global_step = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    for epoch in range(int(args.num_train_epochs)):
        for step, batch in enumerate(
                tqdm(train_dataloader, desc="Iteration", ascii=True)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids = batch
            if input_ids.size(0) != args.train_batch_size:
                continue

            att_loss = 0.
            rep_loss = 0.

            student_atts, student_reps = student_model(
                input_ids, segment_ids, input_mask)
            teacher_reps, teacher_atts, _ = teacher_model(
                input_ids, segment_ids, input_mask)
            # Detach teacher outputs so no gradients flow into the teacher
            # (roughly a 1.5x speedup).
            teacher_reps = [
                teacher_rep.detach() for teacher_rep in teacher_reps
            ]
            teacher_atts = [
                teacher_att.detach() for teacher_att in teacher_atts
            ]

            teacher_layer_num = len(teacher_atts)
            student_layer_num = len(student_atts)
            assert teacher_layer_num % student_layer_num == 0
            layers_per_block = int(teacher_layer_num / student_layer_num)
            new_teacher_atts = [
                teacher_atts[i * layers_per_block + layers_per_block - 1]
                for i in range(student_layer_num)
            ]

            for student_att, teacher_att in zip(student_atts,
                                                new_teacher_atts):
                student_att = torch.where(
                    student_att <= -1e2,
                    torch.zeros_like(student_att).to(device), student_att)
                teacher_att = torch.where(
                    teacher_att <= -1e2,
                    torch.zeros_like(teacher_att).to(device), teacher_att)
                att_loss += loss_mse(student_att, teacher_att)

            new_teacher_reps = [
                teacher_reps[i * layers_per_block]
                for i in range(student_layer_num + 1)
            ]
            new_student_reps = student_reps

            for student_rep, teacher_rep in zip(new_student_reps,
                                                new_teacher_reps):
                rep_loss += loss_mse(student_rep, teacher_rep)

            loss = att_loss + rep_loss

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()

            tr_att_loss += att_loss.item()
            tr_rep_loss += rep_loss.item()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
            mean_att_loss = tr_att_loss * args.gradient_accumulation_steps / nb_tr_steps
            mean_rep_loss = tr_rep_loss * args.gradient_accumulation_steps / nb_tr_steps
            if step % 100 == 0:
                logger.info(f'mean_loss = {mean_loss}')

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                if (global_step + 1) % args.eval_step == 0:
                    result = {}
                    result['global_step'] = global_step
                    result['loss'] = mean_loss
                    result['att_loss'] = mean_att_loss
                    result['rep_loss'] = mean_rep_loss
                    output_eval_file = os.path.join(args.output_dir, "log.txt")
                    with open(output_eval_file, "a") as writer:
                        logger.info("***** Eval results *****")
                        for key in sorted(result.keys()):
                            logger.info("  %s = %s", key, str(result[key]))
                            writer.write("%s = %s\n" %
                                         (key, str(result[key])))

                    # Save a trained model
                    prefix = f"step_{step}"
                    save_model(prefix, student_model, args.output_dir)

        prefix = f"epoch_{epoch}"
        save_model(prefix, student_model, args.output_dir)
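
# Standalone sketch of the layer-mapping rule used in the distillation loops
# above (illustrative numbers: a 4-layer student distilling a 12-layer
# teacher); runnable on its own, no models involved:
def _layer_mapping_sketch(teacher_layer_num=12, student_layer_num=4):
    layers_per_block = teacher_layer_num // student_layer_num
    # each student attention layer mimics the last teacher layer of its block
    att_map = [i * layers_per_block + layers_per_block - 1
               for i in range(student_layer_num)]
    # hidden states (embeddings + layers) are matched block-by-block
    rep_map = [i * layers_per_block for i in range(student_layer_num + 1)]
    return att_map, rep_map  # ([2, 5, 8, 11], [0, 3, 6, 9, 12])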