Example #1
def get_model_and_tokenizer(args):
    config_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    model_config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None)
    config = BertForSeq2SeqConfig.from_exist_config(
        config=model_config,
        label_smoothing=args.label_smoothing,
        max_position_embeddings=args.max_source_seq_length +
        args.max_target_seq_length)

    logger.info("Model config for seq2seq: %s", str(config))

    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)

    model = BertForSequenceToSequence.from_pretrained(
        args.model_name_or_path,
        config=config,
        model_type=args.model_type,
        reuse_position_embedding=True,
        cache_dir=args.cache_dir if args.cache_dir else None)

    return model, tokenizer
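Example #1 only reads a handful of fields from `args`, so it can be exercised outside the full training script. Below is a minimal usage sketch with illustrative values; it assumes an s2s-ft style setup in which `MODEL_CLASSES` contains a "unilm" entry and `get_model_and_tokenizer` is importable.

# Hypothetical usage sketch; the field names mirror the args accessed above,
# the concrete values are illustrative.
from argparse import Namespace

args = Namespace(
    model_type="unilm",                     # key into MODEL_CLASSES (assumed to exist)
    model_name_or_path="unilm-base-cased",  # checkpoint name or local directory
    config_name=None,                       # fall back to model_name_or_path
    tokenizer_name=None,                    # fall back to model_name_or_path
    cache_dir=None,
    do_lower_case=False,
    label_smoothing=0.1,
    max_source_seq_length=464,
    max_target_seq_length=48,
)
model, tokenizer = get_model_and_tokenizer(args)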
Example #2
def get_model_and_tokenizer(args):
    model_config = UnilmConfig.from_pretrained(
        args.config_name if args.config_name else 'unilm-base-cased',
        cache_dir=args.cache_dir if args.cache_dir else None)
    config = BertForSeq2SeqConfig.from_exist_config(
        config=model_config,
        label_smoothing=args.label_smoothing,
        max_position_embeddings=args.max_source_seq_length +
        args.max_target_seq_length)

    logger.info("Model config for seq2seq: %s", str(config))

    tokenizer = UnilmTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else 'unilm-base-cased',
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)

    generator = BertForSequenceToSequence.from_pretrained(
        'unilm-base-cased',
        config=config,
        model_type='unilm',
        reuse_position_embedding=True,
        cache_dir=args.cache_dir if args.cache_dir else None)
    generator.to(args.device)

    classifier = Classifier(config.hidden_size, args.num_labels)
    classifier.to(args.device)

    logger.info("Initialize retriever.")
    retriever = Retriever(args, tokenizer)
    return generator, classifier, tokenizer, retriever
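`Classifier` and `Retriever` in Example #2 are project-specific modules, not part of Transformers. Purely as an illustration of the constructor signature `Classifier(config.hidden_size, args.num_labels)`, such a head can be as small as a dropout plus a linear projection; the actual implementation in the source project may differ.

# Hypothetical sketch of a classification head with the signature used above.
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, hidden_size, num_labels, dropout_prob=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled_output):
        # pooled_output: (batch_size, hidden_size), e.g. a [CLS] representation
        return self.out_proj(self.dropout(pooled_output))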
Example #3
def get_model_and_tokenizer(args):
    config_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    # Hack to cope with updated version of Transformers API
    if args.model_type in ['minilm', 'unilm', 'xbert']:
        config_file = config_class.pretrained_config_archive_map[
            args.model_name_or_path]
        vocab_file = tokenizer_class.pretrained_vocab_files_map['vocab_file'][
            args.model_name_or_path]
        model_file = args.model_name_or_path
    elif os.path.exists(args.model_name_or_path):
        vocab_file = os.path.join(args.model_name_or_path, 'vocab.txt')
        config_file = os.path.join(args.model_name_or_path, 'config.json')
        model_file = os.path.join(args.model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(vocab_file)
        assert os.path.exists(config_file)
        assert os.path.exists(model_file)
    else:
        vocab_file = args.model_name_or_path
        config_file = args.model_name_or_path
        model_file = args.model_name_or_path

    model_config = config_class.from_pretrained(
        args.config_name if args.config_name else config_file,
        cache_dir=args.cache_dir if args.cache_dir else None)

    config = BertForSeq2SeqConfig.from_exist_config(
        config=model_config,
        label_smoothing=args.label_smoothing,
        max_position_embeddings=args.max_source_seq_length +
        args.max_target_seq_length)

    logger.info("Model config for seq2seq: %s", str(config))

    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else vocab_file,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)

    model = BertForSequenceToSequence.from_pretrained(
        model_file,
        config=config,
        model_type=args.model_type,
        reuse_position_embedding=True,
        cache_dir=args.cache_dir if args.cache_dir else None)

    return model, tokenizer
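Example #3 works around API changes between Transformers releases: older releases exposed `pretrained_config_archive_map` on config classes and `pretrained_vocab_files_map` on tokenizer classes, and the custom classes behind 'minilm', 'unilm', and 'xbert' still provide them, so that branch resolves config, vocab, and weights through those maps, while local directories and hub names go through the usual paths. For reference, the `MODEL_CLASSES` registry assumed by all of these examples typically has the following shape; the exact entries depend on the project and the installed Transformers version.

# Illustrative shape of the MODEL_CLASSES registry; the unilm/minilm entries come
# from the s2s-ft package, not from Transformers itself.
from transformers import BertConfig, BertTokenizer, RobertaConfig, RobertaTokenizer

MODEL_CLASSES = {
    "bert": (BertConfig, BertTokenizer),
    "roberta": (RobertaConfig, RobertaTokenizer),
    # "unilm": (UnilmConfig, UnilmTokenizer),    # provided by s2s-ft
    # "minilm": (MinilmConfig, MinilmTokenizer), # provided by s2s-ft
}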
Example #4
def get_model_and_tokenizer(args):
    config_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    model_config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None)
    config = BertForSeq2SeqConfig.from_exist_config(
        config=model_config,
        label_smoothing=args.label_smoothing,
        max_position_embeddings=args.max_source_seq_length +
        args.max_target_seq_length)

    logger.info("Model config for seq2seq: %s", str(config))

    if args.prepend_len:
        tgt_segments = [85] + list(range(100, 400, 15)) + [400]
        additional_special_tokens = [f'[unused{seg}]' for seg in tgt_segments]
        logger.info(f'additional_special_tokens: {additional_special_tokens}')
        tokenizer = tokenizer_class.from_pretrained(
            args.tokenizer_name
            if args.tokenizer_name else args.model_name_or_path,
            do_lower_case=args.do_lower_case,
            cache_dir=args.cache_dir if args.cache_dir else None,
            additional_special_tokens=additional_special_tokens)
    else:
        tokenizer = tokenizer_class.from_pretrained(
            args.tokenizer_name
            if args.tokenizer_name else args.model_name_or_path,
            do_lower_case=args.do_lower_case,
            cache_dir=args.cache_dir if args.cache_dir else None)

    model = BertForSequenceToSequence.from_pretrained(
        args.model_name_or_path,
        config=config,
        model_type=args.model_type,
        reuse_position_embedding=True,
        cache_dir=args.cache_dir if args.cache_dir else None)

    return model, tokenizer
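The `prepend_len` branch in Example #4 registers `[unusedN]` placeholders as additional special tokens. Because those entries already exist in the BERT/UniLM vocabulary, they should map to their existing ids rather than new ones, so no embedding resize is expected; exact behaviour can vary across Transformers versions. A quick, hedged sanity check on the returned tokenizer:

# Hedged sanity check: every [unused*] marker should already have a vocabulary id
# (i.e. none of them falls back to [UNK]).
tgt_segments = [85] + list(range(100, 400, 15)) + [400]
specials = [f"[unused{seg}]" for seg in tgt_segments]
ids = tokenizer.convert_tokens_to_ids(specials)
assert tokenizer.unk_token_id not in ids, "a [unused*] token is missing from the vocab"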
def train(args, training_features, doc_features, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0] and args.log_dir:
        tb_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        tb_writer = None

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
    else:
        amp = None

    # model recover
    recover_step = utils.get_max_epoch_model(args.output_dir)

    # if recover_step:
    #     model_recover_checkpoint = os.path.join(args.output_dir, "model.{}.bin".format(recover_step))
    #     logger.info(" ** Recover model checkpoint in %s ** ", model_recover_checkpoint)
    #     model_state_dict = torch.load(model_recover_checkpoint, map_location='cpu')
    #     optimizer_recover_checkpoint = os.path.join(args.output_dir, "optim.{}.bin".format(recover_step))
    #     checkpoint_state_dict = torch.load(optimizer_recover_checkpoint, map_location='cpu')
    #     checkpoint_state_dict['model'] = model_state_dict
    # else:
    checkpoint_state_dict = None

    model.to(args.device)
    model, optimizer = prepare_for_training(args, model, checkpoint_state_dict, amp=amp)

    if args.n_gpu == 0 or args.no_cuda:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.gradient_accumulation_steps
    else:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.n_gpu * args.gradient_accumulation_steps
        
    train_batch_size = per_node_train_batch_size * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    global_step = recover_step if recover_step else 0

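    # A value of -1 means "derive the number of optimization steps from the epoch
    # count and the effective global batch size (including gradient accumulation)
    # computed above".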
    if args.num_training_steps == -1:
        args.num_training_steps = int(args.num_training_epochs * len(training_features) / train_batch_size)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.num_training_steps, last_epoch=-1)

    if checkpoint_state_dict:
        scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"])

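    # offset accounts for instances already consumed before a restart, and
    # num_training_instances sizes the dataset so that one pass corresponds to the
    # full training schedule.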
    train_dataset = utils.RetrievalSeq2seqDatasetForBert(
        features=training_features, max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length, vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id, sep_id=tokenizer.sep_token_id, pad_id=tokenizer.pad_token_id,
        mask_id=tokenizer.mask_token_id, random_prob=args.random_prob, keep_prob=args.keep_prob,
        offset=train_batch_size * global_step, num_training_instances=train_batch_size * args.num_training_steps,
    )

    concator = utils.Concator(
        max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length, vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id, sep_id=tokenizer.sep_token_id, pad_id=tokenizer.pad_token_id,
        mask_id=tokenizer.mask_token_id, random_prob=args.random_prob, keep_prob=args.keep_prob,
        offset=train_batch_size * global_step, num_training_instances=train_batch_size * args.num_training_steps,
    )

    if hasattr(model, "module"):
        model.module.concator = concator
    else:
        model.concator = concator

    # build document embeddings
    logger.info("Building embeddings for %d documents" % len(doc_features))
    doc_dataset = utils.RetrievalSeq2seqDocDatasetForBert(
        features=doc_features, max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length, vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id, sep_id=tokenizer.sep_token_id, pad_id=tokenizer.pad_token_id,
        mask_id=tokenizer.mask_token_id, random_prob=args.random_prob, keep_prob=args.keep_prob,
        offset=train_batch_size * global_step, num_training_instances=train_batch_size * args.num_training_steps,
    )
    doc_sampler = SequentialSampler(doc_dataset)
    doc_dataloader = DataLoader(
            doc_dataset, sampler=doc_sampler,
            batch_size=per_node_train_batch_size // args.gradient_accumulation_steps,
            collate_fn=utils.batch_list_to_batch_tensors)
    doc_iterator = tqdm.tqdm(
            doc_dataloader,
            desc="Embedding docs", disable=args.local_rank not in [-1, 0])
    all_embeds = []

    model.eval()
    model.zero_grad()
    for step, batch in enumerate(doc_iterator):
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            embeds = model.module.retrieval.get_embeds(batch[0]) if hasattr(model, "module") else model.retrieval.get_embeds(batch[0])
        all_embeds.extend(embeds.view(-1, 768).detach().cpu().tolist())
    
    if hasattr(model, "module"):
        model.module.retrieval.doc_embeds = torch.tensor(all_embeds, dtype=torch.float32)    

        model.module.retrieval.build_indexs_from_embeds(model.module.retrieval.doc_embeds)
    else:
        model.retrieval.doc_embeds = torch.tensor(all_embeds, dtype=torch.float32)    

        model.retrieval.build_indexs_from_embeds(model.retrieval.doc_embeds)

    logger.info("start training")

    if args.ckpt_path:
        logger.info("continue training from %s" % args.ckpt_path)
        config_class, tokenizer_class = MODEL_CLASSES[args.model_type]
        model_config = config_class.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            cache_dir=args.cache_dir if args.cache_dir else None)
        config = BertForSeq2SeqConfig.from_exist_config(
            config=model_config, label_smoothing=args.label_smoothing,
            max_position_embeddings=args.max_source_seq_length + args.max_target_seq_length)
        model = BertForRetrievalSeq2Seq.from_pretrained(
            args.ckpt_path, config=config, model_type=args.model_type,
            reuse_position_embedding=True, retrieval=config,
            cache_dir=args.cache_dir if args.cache_dir else None)
        model.to(args.device)  # the reloaded model must live on the same device as the batches

    logger.info("Check dataset:")
    for i in range(5):
        source_ids, target_ids, num_source_tokens, num_target_tokens = train_dataset.__getitem__(i)
        logger.info("Instance-%d" % i)
        logger.info("Source tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(source_ids)))
        logger.info("Target tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(target_ids)))

    logger.info("Mode = %s" % str(model))

    # Train!
    logger.info("  ***** Running training *****  *")
    logger.info("  Num examples = %d", len(training_features))
    logger.info("  Num Epochs = %.2f", len(train_dataset) / len(training_features))
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Batch size per node = %d", per_node_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", train_batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", args.num_training_steps)

    if args.num_training_steps <= global_step:
        logger.info("Training is done. Please use a new dir or clean this dir!")
    else:
        # The training features are already shuffled, so sequential sampling
        # (and shuffle=False for the distributed sampler) is sufficient here
        train_sampler = SequentialSampler(train_dataset) \
            if args.local_rank == -1 else DistributedSampler(train_dataset, shuffle=False)
        train_dataloader = DataLoader(
            train_dataset, sampler=train_sampler,
            batch_size=per_node_train_batch_size // args.gradient_accumulation_steps,
            collate_fn=utils.batch_list_to_batch_tensors)

        train_iterator = tqdm.tqdm(
            train_dataloader, initial=global_step,
            desc="Iter (loss=X.XXX, lr=X.XXXXXXX)", disable=args.local_rank not in [-1, 0])

        model.train()
        model.zero_grad()

        tr_loss, logging_loss = 0.0, 0.0

        for step, batch in enumerate(train_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'source_ids': batch[0],
                      'target_ids': batch[1],
                    #   'pseudo_ids': batch[2],
                      'num_source_tokens': batch[2],
                      'num_target_tokens': batch[3]}
            loss = model(**inputs)
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training

            train_iterator.set_description('Iter (loss=%5.3f) lr=%9.7f' % (loss.item(), scheduler.get_lr()[0]))

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            logging_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info("")
                    logger.info(" Step [%d ~ %d]: %.2f", global_step - args.logging_steps, global_step, logging_loss)
                    logging_loss = 0.0

                if args.local_rank in [-1, 0] and args.save_steps > 0 and \
                        (global_step % args.save_steps == 0 or global_step == args.num_training_steps):

                    save_path = os.path.join(args.output_dir, "ckpt-%d" % global_step)
                    os.makedirs(save_path, exist_ok=True)
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(save_path)

                    
                    # optim_to_save = {
                    #     "optimizer": optimizer.state_dict(),
                    #     "lr_scheduler": scheduler.state_dict(),
                    # }
                    # if args.fp16:
                    #     optim_to_save["amp"] = amp.state_dict()
                    # torch.save(
                    #     optim_to_save, os.path.join(args.output_dir, 'optim.{}.bin'.format(global_step)))

                    logger.info("Saving model checkpoint %d into %s", global_step, save_path)

    if args.local_rank in [-1, 0] and tb_writer:
        tb_writer.close()
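The loop above checkpoints with `save_pretrained()` into `ckpt-<global_step>` directories. Below is a hedged sketch of reloading such a checkpoint, reusing the same `from_pretrained` pattern as the examples at the top (the retrieval variant would use `BertForRetrievalSeq2Seq` in the same way); it assumes the `args` and `config` objects from training are still in scope, and the step number is illustrative.

# Hedged sketch: reload a checkpoint written by the save_pretrained() call above.
ckpt_dir = os.path.join(args.output_dir, "ckpt-%d" % args.num_training_steps)
model = BertForSequenceToSequence.from_pretrained(
    ckpt_dir,
    config=config,                  # the same BertForSeq2SeqConfig used for training
    model_type=args.model_type,
    reuse_position_embedding=True,
    cache_dir=args.cache_dir if args.cache_dir else None)
model.to(args.device)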
    def __init__(
        self,
        model_name="unilm-base-cased",
        to_lower=False,
        cache_dir=".",
        load_model_from_dir=None,
        model_file_name=None,
        label_smoothing=0.1,
        max_seq_length=512,
        max_source_seq_length=464,
        max_target_seq_length=48,
    ):
        """
        Abstractive summarizer based on s2s-ft.

        Args:
            model_name (str, optional): Name of the model.
                Call `S2SAbstractiveSummarizer.list_supported_models()` to see all
                supported model names. Defaults to "unilm-base-cased".
            to_lower (bool, optional): Whether to convert all letters to lower case
                during tokenization. This is determined by if a cased model is used.
                Defaults to False, which corresponds to a cased model.
            cache_dir (str, optional): Directory to cache downloaded model files.
                Defaults to ".".
            load_model_from_dir (str, optional): Directory to load the model from. If
                model_file_name is not provided, assume model was saved by
                :func:`~transformers.PreTrainedModel.save_pretrained` and the
                directory should contain pytorch_model.bin and config.json.
                Defaults to None.
            model_file_name (str, optional): Name of the model file under
                `load_model_from_dir`. If provided, assume model was saved by
                `S2SAbstractiveSummarizer.save_model`.
            label_smoothing (float, optional): Alpha in label smoothing.
                Defaults to 0.1.
            max_seq_length (int, optional): Maximum length of the sequence that
                concatenates source sequence tokens, target sequence tokens, and
                special tokens like cls and sep. Defaults to 512.
            max_source_seq_length (int, optional): Maximum number of tokens in the
                source sequence after tokenization. Defaults to 464.
            max_target_seq_length (int, optional): Maximum number of tokens in the
                target sequence after tokenization. Defaults to 48.

        """

        if model_name not in self.list_supported_models():
            raise ValueError(
                "Model name {0} is not supported by {1}. "
                "Call '{1}.list_supported_models()' to get all supported model "
                "names.".format(model_name, self.__class__.__name__))
        model_class = MODEL_CLASS[model_name]
        config_class = CONFIG_CLASS[model_name]

        self._model_name = model_name
        self._model_type = _get_model_type(self._model_name)

        # self._bert_model_name is needed for BertForSeq2SeqDecoder
        if self._model_type != "bert":
            if self._model_type == "roberta":
                self._bert_model_name = (
                    self._model_name.replace("roberta", "bert") + "-cased")
            else:
                self._bert_model_name = "bert-" + self._model_name.split(
                    "-", 1)[-1]
        else:
            self._bert_model_name = self._model_name

        self.cache_dir = cache_dir
        self.load_model_from_dir = load_model_from_dir
        self.do_lower_case = to_lower
        self.max_seq_length = max_seq_length
        self.max_source_seq_length = max_source_seq_length
        self.max_target_seq_length = max_target_seq_length

        if load_model_from_dir is None:
            model_to_load = self._model_name
        elif model_file_name is None:
            # Assume model was saved by
            # :func:`~transformers.PreTrainedModel.save_pretrained`;
            # load_model_from_dir should contain pytorch_model.bin and config.json
            # and can be loaded by
            # :func:`~transformers.PreTrainedModel.from_pretrained`.
            logger.info(
                "Loading cached model from {}".format(load_model_from_dir))
            model_to_load = load_model_from_dir
        else:
            # Assume model was saved by S2SAbstractiveSummarizer.save_model
            model_to_load = os.path.join(load_model_from_dir, model_file_name)
            logger.info("Loading cached model from {}".format(model_to_load))

        if load_model_from_dir is not None and model_file_name is None:
            # Assume config.json is in load_model_from_dir
            model_config = config_class.from_pretrained(load_model_from_dir,
                                                        cache_dir=cache_dir)
        else:
            model_config = config_class.from_pretrained(self._model_name,
                                                        cache_dir=cache_dir)

        # Convert regular model config to sequence to sequence config
        config = BertForSeq2SeqConfig.from_exist_config(
            config=model_config,
            label_smoothing=label_smoothing,
            max_position_embeddings=self.max_source_seq_length +
            self.max_target_seq_length,
        )
        logger.info("Model config for seq2seq: %s", str(config))

        self.model = model_class.from_pretrained(
            model_to_load,
            config=config,
            model_type=self._model_type,
            cache_dir=cache_dir,
            reuse_position_embedding=True,
        )

        self.tokenizer = TOKENIZER_CLASS[model_name].from_pretrained(
            self._model_name,
            do_lower_case=to_lower,
            cache_dir=cache_dir,
            output_loading_info=False,
        )
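The constructor parameters documented above map directly onto these loading paths. A minimal instantiation sketch, assuming the class is the `S2SAbstractiveSummarizer` named in the docstring; all values are illustrative.

# Hedged usage sketch of the constructor documented above.
summarizer = S2SAbstractiveSummarizer(
    model_name="unilm-base-cased",
    to_lower=False,              # cased model, so keep case
    cache_dir="./cache",
    label_smoothing=0.1,
    max_seq_length=512,
    max_source_seq_length=464,
    max_target_seq_length=48,
)

# Reloading a model previously saved with save_pretrained():
# summarizer = S2SAbstractiveSummarizer(
#     model_name="unilm-base-cased",
#     load_model_from_dir="./fine_tuned_model_dir")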