Exemple #1
0
    def train(self,
              train_data,
              dev_data=None,
              resume_path=None,
              start_epoch=1,
              state_to_save=None):
        """Run the training loop with periodic evaluation and checkpointing.

        Args:
            train_data: Training dataset; passed to
                ``self.build_train_dataloader`` and measured with ``len()``
                for the summary.
            dev_data: Dataset handed to ``self.evaluate`` when
                ``opts.evaluate_during_training`` is set.
            resume_path: Forwarded to ``self.resume_from_checkpoint``.
            start_epoch: First epoch number of the loop; epochs run up to
                and including ``int(self.num_train_epochs)``.
            state_to_save: Optional mapping of extra keyword arguments passed
                to ``self.build_state_object`` whenever a checkpoint is
                written. ``None`` (the default) means no extra entries.

        Raises:
            TypeError: If the metric monitored by the model checkpoint (when
                evaluating during training) or by early stopping is missing
                from ``self.records['result']``.
        """
        # Build a fresh dict per call; a ``dict()`` default is created once
        # at definition time and would be shared by every call.
        if state_to_save is None:
            state_to_save = {}

        def monitored_value(monitor):
            # Fetch the latest evaluation result for ``monitor`` or fail
            # with the list of keys that are actually available.
            results = self.records['result']
            if monitor not in results:
                msg = ("There were expected keys in the eval result: "
                       f"{', '.join(list(results.keys()))}, "
                       f"but get {monitor}.")
                raise TypeError(msg)
            return results[monitor]

        train_dataloader = self.build_train_dataloader(train_data)
        steps_in_epoch = len(train_dataloader)
        num_training_steps = (steps_in_epoch //
                              self.gradient_accumulation_steps *
                              self.num_train_epochs)
        self.steps_in_epoch = steps_in_epoch
        if self.scheduler is None:
            self.scheduler = self.build_lr_scheduler(num_training_steps)
        self.resume_from_checkpoint(resume_path=resume_path)
        self.build_model_warp()
        self.print_summary(len(train_data), num_training_steps)
        self.optimizer.zero_grad()
        seed_everything(
            self.opts.seed, verbose=False
        )  # Added here for reproducibility (even between python 2 and 3)

        # A negative interval means "once per epoch", measured in optimizer
        # updates and never less than one.
        updates_per_epoch = max(
            1, steps_in_epoch // self.gradient_accumulation_steps)
        if self.opts.logging_steps < 0:
            self.opts.logging_steps = updates_per_epoch
        if self.opts.save_steps < 0:
            self.opts.save_steps = updates_per_epoch

        self.build_record_tracker()
        self.reset_metrics()
        pbar = ProgressBar(n_total=steps_in_epoch,
                           desc='Training',
                           num_epochs=self.num_train_epochs)
        for epoch in range(start_epoch, int(self.num_train_epochs) + 1):
            pbar.epoch(current_epoch=epoch)
            for step, batch in enumerate(train_dataloader):
                outputs, should_logging, should_save = self.train_step(
                    step, batch)
                if outputs is not None:
                    # Same guard as the EMA evaluation below: skip the
                    # update when no EMA model has been built.
                    if self.opts.ema_enable and self.model_ema is not None:
                        self.model_ema.update(self.model)
                    pbar.step(step, {'loss': outputs['loss'].item()})

                if (self.opts.logging_steps > 0 and self.global_step > 0
                        and should_logging
                        and self.opts.evaluate_during_training):
                    self.evaluate(dev_data)
                    if self.opts.ema_enable and self.model_ema is not None:
                        self.evaluate(dev_data, prefix_metric='ema')
                    if hasattr(self.writer, 'save'):
                        self.writer.save()

                if (self.opts.save_steps > 0 and self.global_step > 0
                        and should_save and self.model_checkpoint):
                    # model checkpoint
                    state = self.build_state_object(**state_to_save)
                    if self.opts.evaluate_during_training:
                        current = monitored_value(
                            self.model_checkpoint.monitor)
                    else:
                        # No evaluation results to compare against.
                        current = None
                    self.model_checkpoint.step(state=state, current=current)

            # early_stopping
            if self.early_stopping:
                self.early_stopping.step(
                    current=monitored_value(self.early_stopping.monitor))
                if self.early_stopping.stop_training:
                    break
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        if self.writer:
            self.writer.close()
Exemple #2
0
def main():
    """Train, evaluate and/or predict with a sequence-labeling model.

    The phases are selected by the parsed options ``do_train``, ``do_eval``
    and ``do_predict``. ``checkpoint_predict_code`` names a checkpoint
    directory under ``output_dir``; for evaluation,
    ``eval_all_checkpoints`` replaces that selection with whatever
    ``find_all_checkpoints`` returns for ``output_dir``.
    """
    opts = Argparser().get_training_arguments()
    logger = Logger(opts=opts)
    # device
    logger.info("initializing device")
    opts.device, opts.device_num = prepare_device(opts.device_id)
    seed_everything(opts.seed)
    config_class, model_class, tokenizer_class = MODEL_CLASSES[opts.model_type]
    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        opts.pretrained_model_path, do_lower_case=opts.do_lower_case)
    train_dataset = load_data(opts.train_input_file, opts.data_dir, "train",
                              tokenizer, opts.train_max_seq_length)
    dev_dataset = load_data(opts.eval_input_file, opts.data_dir, "dev",
                            tokenizer, opts.eval_max_seq_length)
    opts.num_labels = train_dataset.num_labels
    opts.label2id = CnerDataset.label2id()
    opts.id2label = CnerDataset.id2label()

    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(opts.pretrained_model_path,
                                          num_labels=opts.num_labels,
                                          label2id=opts.label2id,
                                          id2label=opts.id2label)
    model = model_class.from_pretrained(opts.pretrained_model_path,
                                        config=config)
    model.to(opts.device)

    # trainer
    logger.info("initializing traniner")
    # Entity types are the part after the tag prefix, e.g. "B-NAME" -> "NAME".
    labels = {
        label.split('-')[1]
        for label in CnerDataset.get_labels() if '-' in label
    }
    metrics = [
        SequenceLabelingScore(labels=labels, average='micro', schema='BIOS')
    ]
    trainer = SequenceLabelingTrainer(opts=opts,
                                      model=model,
                                      tokenizer=tokenizer,
                                      metrics=metrics,
                                      logger=logger)

    def requested_checkpoints():
        # Checkpoint path selected through ``opts.checkpoint_predict_code``,
        # as a list (empty when the option is unset).
        if opts.checkpoint_predict_code is None:
            return []
        checkpoint = os.path.join(opts.output_dir,
                                  opts.checkpoint_predict_code)
        check_dir(checkpoint)
        return [checkpoint]

    def load_checkpoint(checkpoint):
        # Install the weights stored in ``checkpoint`` on the trainer and
        # return the directory name used to label the saved results.
        restored = model_class.from_pretrained(checkpoint, config=config)
        restored.to(opts.device)
        trainer.model = restored
        # basename/normpath instead of split("/"): works with the separator
        # os.path.join produced and tolerates a trailing separator.
        return os.path.basename(os.path.normpath(checkpoint))

    # do train
    if opts.do_train:
        trainer.train(train_data=train_dataset, dev_data=dev_dataset)

    if opts.do_eval:
        checkpoints = requested_checkpoints()
        if opts.eval_all_checkpoints:
            checkpoints = find_all_checkpoints(checkpoint_dir=opts.output_dir)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            prefix = load_checkpoint(checkpoint)
            trainer.evaluate(dev_data=dev_dataset,
                             save_result=True,
                             save_dir=prefix)

    if opts.do_predict:
        # Load the test split only when it is needed, so train/eval-only
        # runs do not require a test file.
        test_dataset = load_data(opts.test_input_file, opts.data_dir, "test",
                                 tokenizer, opts.test_max_seq_length)
        checkpoints = requested_checkpoints()
        logger.info("Predict the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            prefix = load_checkpoint(checkpoint)
            trainer.predict(test_data=test_dataset,
                            save_result=True,
                            save_dir=prefix)
Exemple #3
0
def main():
    """Train, evaluate and/or predict with a global-pointer tagging model.

    Extends the training argument parser with the ``--decode_thresh``,
    ``--pe_dim`` and ``--use_rope`` options, then runs the phases selected
    by ``do_train``, ``do_eval`` and ``do_predict``.
    ``checkpoint_predict_code`` names a checkpoint directory under
    ``output_dir``; for evaluation, ``eval_all_checkpoints`` replaces that
    selection with whatever ``find_all_checkpoints`` returns.
    """
    parser = Argparser.get_training_parser()
    group = parser.add_argument_group(title="global pointer",
                                      description="Global pointer")
    group.add_argument("--decode_thresh", type=float, default=0.0)
    group.add_argument('--pe_dim',
                       default=64,
                       type=int,
                       help='The dim of Positional embedding')
    group.add_argument('--use_rope', action='store_true')
    opts = parser.parse_args_from_parser(parser)
    logger = Logger(opts=opts)
    # device
    logger.info("initializing device")
    opts.device, opts.device_num = prepare_device(opts.device_id)
    seed_everything(opts.seed)
    config_class, model_class, tokenizer_class = MODEL_CLASSES[opts.model_type]
    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        opts.pretrained_model_path, do_lower_case=opts.do_lower_case)
    train_dataset = load_data(opts.train_input_file, opts.data_dir, "train",
                              tokenizer, opts.train_max_seq_length)
    dev_dataset = load_data(opts.eval_input_file, opts.data_dir, "dev",
                            tokenizer, opts.eval_max_seq_length)
    opts.num_labels = train_dataset.num_labels
    opts.label2id = CnerDataset.label2id()
    opts.id2label = CnerDataset.id2label()
    # model
    logger.info("initializing model and config")
    config, unused_kwargs = config_class.from_pretrained(
        opts.pretrained_model_path,
        return_unused_kwargs=True,
        pe_dim=opts.pe_dim,
        use_rope=opts.use_rope,
        num_labels=opts.num_labels,
        id2label=opts.id2label,
        label2id=opts.label2id,
        decode_thresh=opts.decode_thresh,
        max_seq_length=512)
    # FIXED: by default `from_dict` only sets values for keys that already
    # exist in the config, so force the remaining ones onto it here.
    for key, value in unused_kwargs.items():
        setattr(config, key, value)
    model = model_class.from_pretrained(opts.pretrained_model_path,
                                        config=config)
    model.to(opts.device)
    # trainer
    logger.info("initializing traniner")
    metrics = [
        SequenceLabelingScore(CnerDataset.get_labels(),
                              schema='BIOS',
                              average='micro')
    ]
    trainer = SequenceLabelingTrainer(opts=opts,
                                      model=model,
                                      tokenizer=tokenizer,
                                      metrics=metrics,
                                      logger=logger)

    def requested_checkpoints():
        # Checkpoint path selected through ``opts.checkpoint_predict_code``,
        # as a list (empty when the option is unset).
        if opts.checkpoint_predict_code is None:
            return []
        checkpoint = os.path.join(opts.output_dir,
                                  opts.checkpoint_predict_code)
        check_dir(checkpoint)
        return [checkpoint]

    def load_checkpoint(checkpoint):
        # Install the weights stored in ``checkpoint`` on the trainer and
        # return the directory name used to label the saved results.
        restored = model_class.from_pretrained(checkpoint, config=config)
        restored.to(opts.device)
        trainer.model = restored
        # basename/normpath instead of split("/"): works with the separator
        # os.path.join produced and tolerates a trailing separator.
        return os.path.basename(os.path.normpath(checkpoint))

    # do train
    if opts.do_train:
        trainer.train(train_data=train_dataset, dev_data=dev_dataset)

    if opts.do_eval:
        checkpoints = requested_checkpoints()
        if opts.eval_all_checkpoints:
            checkpoints = find_all_checkpoints(checkpoint_dir=opts.output_dir)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            prefix = load_checkpoint(checkpoint)
            trainer.evaluate(dev_data=dev_dataset,
                             save_result=True,
                             save_dir=prefix)

    if opts.do_predict:
        # Load the test split only when it is needed, so train/eval-only
        # runs do not require a test file.
        test_dataset = load_data(opts.test_input_file, opts.data_dir, "test",
                                 tokenizer, opts.test_max_seq_length)
        checkpoints = requested_checkpoints()
        logger.info("Predict the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            prefix = load_checkpoint(checkpoint)
            trainer.predict(test_data=test_dataset,
                            save_result=True,
                            save_dir=prefix)
Exemple #4
0
def main():
    """Train, evaluate and/or predict with a text-classification model.

    The phases are selected by the parsed options ``do_train``, ``do_eval``
    and ``do_predict``. ``checkpoint_predict_code`` names a checkpoint
    directory under ``output_dir``; for evaluation,
    ``eval_all_checkpoints`` replaces that selection with whatever
    ``find_all_checkpoints`` returns for ``output_dir``.
    """
    opts = Argparser().get_training_arguments()
    logger = Logger(opts=opts)
    # device
    logger.info("initializing device")
    opts.device, opts.device_num = prepare_device(opts.device_id)
    seed_everything(opts.seed)
    config_class, model_class, tokenizer_class = MODEL_CLASSES[opts.model_type]
    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        opts.pretrained_model_path, do_lower_case=opts.do_lower_case)
    train_dataset = load_data(opts.train_input_file, opts.data_dir, "train",
                              tokenizer, opts.train_max_seq_length)
    dev_dataset = load_data(opts.eval_input_file, opts.data_dir, "dev",
                            tokenizer, opts.eval_max_seq_length)
    opts.num_labels = train_dataset.num_labels
    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(opts.pretrained_model_path,
                                          num_labels=opts.num_labels)
    model = model_class.from_pretrained(opts.pretrained_model_path,
                                        config=config)
    model.to(opts.device)
    # trainer
    logger.info("initializing traniner")
    trainer = TextClassifierTrainer(
        opts=opts,
        model=model,
        tokenizer=tokenizer,
        metrics=[MattewsCorrcoef(num_classes=opts.num_labels)],
        logger=logger)

    def requested_checkpoints():
        # Checkpoint path selected through ``opts.checkpoint_predict_code``,
        # as a list (empty when the option is unset).
        if opts.checkpoint_predict_code is None:
            return []
        checkpoint = os.path.join(opts.output_dir,
                                  opts.checkpoint_predict_code)
        check_dir(checkpoint)
        return [checkpoint]

    def load_checkpoint(checkpoint):
        # Install the weights stored in ``checkpoint`` on the trainer and
        # return the directory name used to label the saved results.
        restored = model_class.from_pretrained(checkpoint, config=config)
        restored.to(opts.device)
        trainer.model = restored
        # basename/normpath instead of split("/"): works with the separator
        # os.path.join produced and tolerates a trailing separator.
        return os.path.basename(os.path.normpath(checkpoint))

    # do train
    if opts.do_train:
        trainer.train(train_data=train_dataset, dev_data=dev_dataset)

    if opts.do_eval:
        checkpoints = requested_checkpoints()
        if opts.eval_all_checkpoints:
            checkpoints = find_all_checkpoints(checkpoint_dir=opts.output_dir)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            prefix = load_checkpoint(checkpoint)
            trainer.evaluate(dev_data=dev_dataset,
                             save_result=True,
                             save_dir=prefix)

    if opts.do_predict:
        test_dataset = load_data(opts.test_input_file, opts.data_dir, "test",
                                 tokenizer, opts.test_max_seq_length)
        checkpoints = requested_checkpoints()
        logger.info("Predict the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            prefix = load_checkpoint(checkpoint)
            trainer.predict(test_data=test_dataset,
                            save_result=True,
                            save_dir=prefix)