Example No. 1
    def from_pretrained(cls, pretrained_model_name, do_lower_case=True):
        """
    Instantiate a PreTrainedBertModel from a pre-trained model file.
    Download and cache the pre-trained model file if needed.
    """
        if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
        else:
            vocab_file = pretrained_model_name
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file)
            if resolved_vocab_file == vocab_file:
                logger.info("loading vocabulary file {}".format(vocab_file))
            else:
                logger.info(
                    "loading vocabulary file {} from cache at {}".format(
                        vocab_file, resolved_vocab_file))
            # Instantiate tokenizer.
            tokenizer = cls(resolved_vocab_file, do_lower_case)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name))
            tokenizer = None
        return tokenizer
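
A minimal usage sketch, assuming the classmethod above belongs to the BertTokenizer of the legacy pytorch_pretrained_bert package (PRETRAINED_VOCAB_ARCHIVE_MAP and cached_path come from that library):

# Sketch only; requires pytorch_pretrained_bert to be installed.
from pytorch_pretrained_bert import BertTokenizer

# A known model name resolves through PRETRAINED_VOCAB_ARCHIVE_MAP and is cached locally;
# any other string is treated as a path or URL to a vocabulary file.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
if tokenizer is not None:
    print(tokenizer.tokenize("Hello, world!"))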
Example No. 2
    def __init__(self, file_path, tokenizer, block_size):
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.length = None

        self.block_size = block_size - tokenizer.num_special_tokens_to_add(
            pair=False)
        logger.info(f"block_size: {block_size}")
Example No. 3
    def __len__(self):
        if self.length is not None:
            return self.length
        self.length = 0
        logger.info("caculating lines at %s", self.file_path)
        with open(self.file_path, "r") as f:
            for line in f.read().splitlines():
                if len(line) > 0 and not line.isspace():
                    self.length += 1

        return self.length
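
The lazy count above reads the whole file into memory via f.read().splitlines(); a streaming alternative sketch that counts non-blank lines without loading the file at once:

def count_nonblank_lines(path: str) -> int:
    """Count lines containing at least one non-whitespace character."""
    count = 0
    with open(path, "r") as f:
        for line in f:  # iterates lazily, one line at a time
            if line.strip():
                count += 1
    return count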
Example No. 4
def _rotate_checkpoints(args,
                        checkpoint_prefix="checkpoint",
                        use_mtime=False) -> None:
    if not args.save_total_limit:
        return
    if args.save_total_limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix,
                                             use_mtime)
    if len(checkpoints_sorted) <= args.save_total_limit:
        return

    number_of_checkpoints_to_delete = max(
        0, len(checkpoints_sorted) - args.save_total_limit)
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info(
            "Deleting older checkpoint [{}] due to args.save_total_limit".format(
                checkpoint))
        shutil.rmtree(checkpoint)
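
_sorted_checkpoints is not shown in this listing; a sketch of what it typically does, assuming it mirrors the helper of the same name in the transformers run_language_modeling example (glob the checkpoint folders, then order them by mtime or by the step number embedded in the folder name):

import glob
import os
import re
from typing import List


def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    ordering_and_checkpoint_path = []
    glob_checkpoints = glob.glob(
        os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))

    for path in glob_checkpoints:
        if use_mtime:
            # Order by file modification time.
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            # Order by the numeric suffix, assuming folder names like "checkpoint-500";
            # folders with a different naming scheme are skipped.
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    return [checkpoint[1] for checkpoint in checkpoints_sorted]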
Example No. 5
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 args,
                 dir_path: str,
                 block_size=1024):
        self.examples = []
        tokenizer_class = tokenizer.__class__.__name__
        cached_features_file = os.path.join(
            dir_path, args.model_type + "_cached2_maskedsents3_" +
            str(block_size) + "_" + tokenizer_class)

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s",
                        cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", dir_path)
            good_docs = bad_docs = 0
            for filename in os.listdir(dir_path):
                try:
                    if not filename.endswith(".json"):
                        continue

                    path = os.path.join(dir_path, filename)
                    with open(path) as json_file:
                        data = json.load(json_file)
                        facts_doc = FactsDoc.Schema().load(data)

                    splitter = SentenceSplitter(language='en')
                    full_text_sentence_split = splitter.split(
                        text=facts_doc.text)
                    sent_one = full_text_sentence_split[START_SENT]
                    sent_two = full_text_sentence_split[END_SENT]
                    inbetween_text = " ".join(
                        full_text_sentence_split[START_SENT + 1:END_SENT])
                    tokenized_sent_one = tokenizer.encode(
                        sent_one,
                        add_special_tokens=False,
                        return_tensors="pt").squeeze(0)
                    tokenized_sent_two = tokenizer.encode(
                        sent_two,
                        add_special_tokens=False,
                        return_tensors="pt").squeeze(0)
                    tokenized_inbetween_text = tokenizer.encode(
                        inbetween_text,
                        add_special_tokens=False,
                        return_tensors="pt").squeeze(0)
                    full_text_tensor = torch.cat(
                        [tokenized_sent_one, tokenized_inbetween_text, tokenized_sent_two],
                        dim=0)
                    mask = torch.cat([
                        torch.ones(tokenized_sent_one.size()),
                        torch.zeros(tokenized_inbetween_text.size()),
                        torch.ones(tokenized_sent_two.size())
                    ])
                    self.examples.append((full_text_tensor, mask))
                    good_docs += 1
                except Exception:
                    bad_docs += 1

            logger.info("finished creating examples for " + dir_path)
            logger.info(
                f"docs with exceptions = {bad_docs} from a total of {bad_docs + good_docs}"
            )
            logger.info("Saving features into cached file %s",
                        cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
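
The (full_text_tensor, mask) pairs built above mark the first and last sentences with 1 and the in-between text with 0; a toy illustration of the same mask construction:

import torch

# Toy token ids standing in for the three tokenized spans above.
first_sent = torch.tensor([11, 12])
inbetween = torch.tensor([13, 14, 15])
last_sent = torch.tensor([16])

full_text = torch.cat([first_sent, inbetween, last_sent], dim=0)
mask = torch.cat([
    torch.ones(first_sent.size()),
    torch.zeros(inbetween.size()),
    torch.ones(last_sent.size()),
])
print(mask)  # tensor([1., 1., 0., 0., 0., 1.]) -> 1 marks the exposed sentences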
Example No. 6
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 args,
                 dir_path: str,
                 block_size=512,
                 is_train=False):
        global GENRES_LIST
        if is_train:
            genre_counter = Counter()

        self.examples = []
        self.raw_examples = []
        tokenizer_class = tokenizer.__class__.__name__
        modelstr = "xlnet"  #args.model_type
        cached_features_file = os.path.join(
            dir_path, modelstr + "x_cached2_maskedfactid3_" + str(block_size) +
            "_" + tokenizer_class)

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s",
                        cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)

            if dir_path.endswith("/"):
                base_path = os.path.dirname(dir_path[:-1])
            else:
                base_path = os.path.dirname(dir_path)
            genre_list_path = os.path.join(base_path, "genre_list.npy")
            GENRES_LIST = np.load(genre_list_path).tolist()
        else:
            logger.info("Creating features from dataset file at %s", dir_path)

            docs_with_facts_counter = docs_without_facts_counter = long_docs = bad_docs = good_docs = 0
            self.facts_docs = []
            for filename in os.listdir(dir_path):
                try:
                    if not filename.endswith(".json"):
                        continue

                    path = os.path.join(dir_path, filename)
                    with open(path) as json_file:
                        data = json.load(json_file)
                        facts_doc = FactsDoc.Schema().load(data)
                        if len(facts_doc.openfacts) < FACTS_NUM:
                            docs_without_facts_counter += 1
                            continue

                    docs_with_facts_counter += 1
                    tokenized_text = tokenizer.tokenize(facts_doc.text)
                    encoded_text =  tokenizer.encode(facts_doc.text,add_special_tokens=False, return_tensors="pt")\
                        .squeeze(0)

                    # if encoded_text.size(-1) > (block_size-20):
                    #     long_docs+=1
                    #     continue

                    prefix_tokens = []  # [tokenizer.additional_special_tokens_ids[1]]
                    ids_text_no_prefix = tokenizer.convert_tokens_to_ids(
                        tokenized_text)
                    ids_text = prefix_tokens + ids_text_no_prefix

                    strpos2index = get_strpos2index(tokenizer, encoded_text,
                                                    len(prefix_tokens))
                    top5salient_facts = sorted(facts_doc.openfacts,
                                               key=lambda x: x.salience)[:5]
                    topfacts_sorted_by_position = sorted(
                        top5salient_facts, key=lambda x: x.position)
                    mask = torch.zeros(encoded_text.size())

                    genre = "unknown"  #facts_doc.Genre.split(",")[0]

                    if is_train:
                        genre_counter[genre] += 1

                    for i in range(len(prefix_tokens)):
                        mask[i] = 1

                    fact_id = 0
                    for fact in topfacts_sorted_by_position:
                        fact_id += 1
                        for tok in fact.token:
                            for i in range(tok.start, tok.end):
                                if i not in strpos2index:
                                    print(
                                        f"string position {i} missing from strpos2index for {tok} in fact: {fact.text}"
                                    )
                                    raise LookupError
                                index = strpos2index[i]
                                mask[index] = fact_id

                    full_text_tensor = torch.tensor(ids_text)
                    self.raw_examples.append((full_text_tensor, mask, genre))
                    good_docs += 1
                except Exception:
                    bad_docs += 1

            if dir_path.endswith("/"):
                base_path = os.path.dirname(dir_path[:-1])
            else:
                base_path = os.path.dirname(dir_path)
            genre_list_path = os.path.join(base_path, "genre_list.npy")
            if is_train:
                GENRES_LIST = [
                    k for k, v in genre_counter.items()
                    if v > GENRE_APPEARENCE_MIN
                ]
                np.save(genre_list_path, np.array(GENRES_LIST))
            else:
                GENRES_LIST = np.load(genre_list_path).tolist()

            for example in self.raw_examples:
                genre = example[2]
                if genre in GENRES_LIST:
                    genre_key = GENRES_LIST.index(genre)
                else:
                    genre_key = GENRES_LIST.index("unknown")

                self.examples.append(
                    (example[0], example[1], torch.tensor(genre_key)))

            logger.info("finished creating examples for " + dir_path)
            logger.info("docs turnes to examples = " + str(good_docs))
            logger.info("len of examples = " + str(len(self.examples)))
            logger.info("docs with facts = " + str(docs_with_facts_counter))
            logger.info("docs without enough facts = " +
                        str(docs_without_facts_counter))
            logger.info("docs too long = " + str(long_docs))
            logger.info("docs with exceptions = " + str(bad_docs))
            logger.info("Saving features into cached file %s",
                        cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
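
GENRES_LIST is shared between the train and eval splits through a small numpy file; a minimal sketch of that round trip:

import numpy as np

genres = ["drama", "comedy", "unknown"]
np.save("genre_list.npy", np.array(genres))    # saved as a fixed-width string array
restored = np.load("genre_list.npy").tolist()  # back to a plain Python list of str
assert restored == genres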
Example No. 7
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 args,
                 dir_path: str,
                 block_size=1024):
        self.examples = []
        tokenizer_class = tokenizer.__class__.__name__
        cached_features_file = os.path.join(
            dir_path, args.model_type + "_cached2_maskedplm2_" +
            str(block_size) + "_" + tokenizer_class)

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s",
                        cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", dir_path)

            docs_with_facts_counter = docs_without_facts_counter = long_docs = bad_docs = 0
            self.facts_docs = []
            for filename in os.listdir(dir_path):
                try:
                    if not filename.endswith(".json"):
                        continue

                    path = os.path.join(dir_path, filename)
                    with open(path) as json_file:
                        data = json.load(json_file)
                        facts_doc = FactsDoc.Schema().load(data)
                        if len(facts_doc.openfacts) < FACTS_NUM:
                            docs_without_facts_counter += 1
                            continue

                    docs_with_facts_counter += 1
                    tokenized_text = tokenizer.tokenize(facts_doc.text)
                    encoded_text =  tokenizer.encode(facts_doc.text,add_special_tokens=False, return_tensors="pt")\
                        .squeeze(0)

                    # if encoded_text.size(-1) > (block_size-20):
                    #     long_docs+=1
                    #     continue

                    prefix_tokens = []  # [tokenizer.additional_special_tokens_ids[1]]
                    ids_text_no_prefix = tokenizer.convert_tokens_to_ids(
                        tokenized_text)
                    ids_text = prefix_tokens + ids_text_no_prefix

                    strpos2index = get_strpos2index(tokenizer, encoded_text,
                                                    len(prefix_tokens))
                    top5salient_facts = sorted(facts_doc.openfacts,
                                               key=lambda x: x.salience)[:5]
                    topfacts_sorted_by_position = sorted(
                        top5salient_facts, key=lambda x: x.position)
                    mask = torch.zeros(encoded_text.size())
                    for i in range(len(prefix_tokens)):
                        mask[i] = 1

                    for fact in topfacts_sorted_by_position:
                        for tok in fact.token:
                            for i in range(tok.start, tok.end):
                                if i not in strpos2index:
                                    print(
                                        f"string position {i} missing from strpos2index for {tok} in fact: {fact.text}"
                                    )
                                    raise LookupError
                                index = strpos2index[i]
                                mask[index] = 1

                    full_text_tensor = torch.tensor(ids_text)
                    self.examples.append((full_text_tensor, mask))
                except Exception:
                    bad_docs += 1

            logger.info("finished creating examples for " + dir_path)
            logger.info("docs with facts = " + str(docs_with_facts_counter))
            logger.info("docs without enough facts = " +
                        str(docs_without_facts_counter))
            logger.info("docs with exceptions = " + str(bad_docs))
            logger.info("Saving features into cached file %s",
                        cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
Example No. 8
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        # tb_writer = SummaryWriter()
        tb_writer = SummaryWriter(
            log_dir=os.path.join(args.output_dir, "log", "train"))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples):
        examples = [x[0] for x in examples]
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=None)
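    # Note: the custom `collate` defined above is not wired in here; with collate_fn=None
    # the DataLoader falls back to PyTorch's default collation of the dataset's tuples.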

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    # Take care of distributed/parallel training
    model_to_resize = model.module if hasattr(model, "module") else model
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])

    set_seed(args)  # Added here for reproducibility

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0],
                              file=sys.stdout,
                              mininterval=10)
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            inputs, labels, _ = batch
            # torch.set_printoptions(profile="full")
            # print(f"Inputs : {tokenizer.convert_ids_to_tokens(inputs.tolist())}")
            # print(f"Labels : {tokenizer.convert_ids_to_tokens(labels.tolist())}")
            # exit(0)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)

            model.train()
            outputs = model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                eval_ppl = -1
                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    # Log metrics
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                        eval_ppl = results["perplexity"]
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}-{}".format(checkpoint_prefix,
                                          round(eval_ppl.item(), 5),
                                          global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
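
A quick sanity check of the t_total arithmetic at the top of train(): with no max_steps cap, the scheduler length is batches-per-epoch // gradient_accumulation_steps * epochs.

def total_optimization_steps(num_batches: int, grad_accum: int, epochs: int,
                             max_steps: int = 0) -> int:
    """Mirrors the t_total computation in train() above."""
    if max_steps > 0:
        return max_steps
    return num_batches // grad_accum * epochs


assert total_optimization_steps(10000, 4, 3) == 7500
assert total_optimization_steps(10000, 4, 3, max_steps=500) == 500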
Example No. 9
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 args,
                 dir_path: str,
                 block_size=1024):
        self.examples = []
        tokenizer_class = tokenizer.__class__.__name__
        cached_features_file = os.path.join(
            dir_path, args.model_type + "_cache43shrt_maskedplm2_" +
            str(block_size) + "_" + tokenizer_class)

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s",
                        cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", dir_path)

            docs_with_facts_counter = docs_without_facts_counter = long_docs = bad_docs = 0
            self.facts_docs = []
            for filename in os.listdir(dir_path):
                try:
                    if not filename.endswith(".json"):
                        continue

                    path = os.path.join(dir_path, filename)
                    with open(path) as json_file:
                        data = json.load(json_file)
                        facts_doc = FactsDoc.Schema().load(data)
                        if len(facts_doc.openfacts) < FACTS_NUM:
                            docs_without_facts_counter += 1
                            continue

                    docs_with_facts_counter += 1
                    tokenized_text = tokenizer.tokenize(facts_doc.text)
                    encoded_text =  tokenizer.encode(facts_doc.text,add_special_tokens=False, return_tensors="pt")\
                        .squeeze(0)

                    if encoded_text.size(-1) > (block_size - 20):
                        long_docs += 1
                        continue

                    selected_facts = sorted(facts_doc.openfacts,
                                            key=lambda x: x.relationEnd)[:1]

                    prefix_tokens = []  # [tokenizer.additional_special_tokens_ids[1]]
                    ids_text_no_prefix = tokenizer.convert_tokens_to_ids(
                        tokenized_text)
                    ids_text = prefix_tokens + ids_text_no_prefix

                    strpos2index = get_strpos2index(tokenizer, encoded_text,
                                                    len(prefix_tokens))
                    splitter = SentenceSplitter(language='en')
                    full_text_sentence_split = splitter.split(
                        text=facts_doc.text)

                    partial_text_sent_split = full_text_sentence_split[:SENTENCE_NUM + 1]

                    partial_encoded_sent_split= [tokenizer.encode(t, add_special_tokens=False, return_tensors="pt") \
                        .squeeze(0) for t in partial_text_sent_split]

                    partial_encoded = torch.cat(partial_encoded_sent_split,
                                                dim=0)
                    mask = torch.zeros(partial_encoded.size())
                    for i in range(len(prefix_tokens)):
                        mask[i] = 1

                    for fact in selected_facts:
                        for tok in fact.token:
                            for i in range(tok.start, tok.end):
                                if i not in strpos2index:
                                    print(
                                        f"string position {i} missing from strpos2index for {tok} in fact: {fact.text}"
                                    )
                                    raise LookupError
                                index = strpos2index[i]
                                mask[index] = 1
                    masked_sent = 0
                    for i, sent in enumerate(partial_encoded_sent_split):
                        start = sum([
                            s.size(0) for s in partial_encoded_sent_split[:i]
                        ])
                        end = start + sent.size(0)
                        if mask[start:end].sum() != 0:
                            masked_sent = i

                    if masked_sent == 0:
                        sents_to_expose = [1, 5]
                    elif masked_sent == 5:
                        sents_to_expose = [0, 4]
                    else:
                        sents_to_expose = [0, 5]

                    for sent_idx in sents_to_expose:
                        start = sum([
                            s.size(0)
                            for s in partial_encoded_sent_split[:sent_idx]
                        ])
                        end = start + partial_encoded_sent_split[
                            sent_idx].size(0)
                        mask[start:end] = 1

                    # clone().detach() avoids calling torch.tensor() on an existing tensor
                    full_text_tensor = partial_encoded.clone().detach()
                    self.examples.append((full_text_tensor, mask))
                except Exception:
                    bad_docs += 1

            logger.info("finished creating examples for " + dir_path)
            logger.info("docs with facts = " + str(docs_with_facts_counter))
            logger.info("docs without enough facts = " +
                        str(docs_without_facts_counter))
            logger.info("docs with exceptions = " + str(bad_docs))
            logger.info("Saving features into cached file %s",
                        cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
Example No. 10
    def prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(
                dataloader,
                description,
                prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (prediction_loss_only
                                if prediction_loss_only is not None else
                                self.args.prediction_loss_only)

        assert not getattr(
            self.model.config, "output_attentions", False
        ), "The prediction loop does not work with `output_attentions=True`."
        assert not getattr(
            self.model.config, "output_hidden_states", False
        ), "The prediction loop does not work with `output_hidden_states=True`."

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        model.eval()

        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero(
        ) or self.args.disable_tqdm
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels = self.prediction_step(
                model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            if loss is not None:
                eval_losses.extend([loss] * batch_size)
            if logits is not None:
                preds = logits if preds is None else nested_concat(
                    preds, logits, dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else nested_concat(
                    label_ids, labels, dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = distributed_concat(
                    preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = distributed_concat(
                    label_ids,
                    num_total_examples=self.num_examples(dataloader))

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = nested_numpify(preds)
        if label_ids is not None:
            label_ids = nested_numpify(label_ids)

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            if self.args.local_rank != -1:
                metrics["eval_loss"] = (distributed_broadcast_scalars(
                    eval_losses,
                    num_total_examples=self.num_examples(
                        dataloader)).mean().item())
            else:
                metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)
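
The final loop simply normalizes metric names so every key carries the eval_ prefix; a toy illustration:

metrics = {"accuracy": 0.9, "eval_loss": 0.3}
for key in list(metrics.keys()):
    if not key.startswith("eval_"):
        metrics[f"eval_{key}"] = metrics.pop(key)
assert metrics == {"eval_loss": 0.3, "eval_accuracy": 0.9}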
Example No. 11
    def train(self,
              model_path: Optional[str] = None,
              trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        num_update_steps_per_epoch = len(
            train_dataloader) // self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0)
        else:
            t_total = int(num_update_steps_per_epoch *
                          self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
            self.args.max_steps = t_total

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"),
                           map_location=self.args.device))
            self.lr_scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})

        # Train!

        total_train_batch_size = (self.args.train_batch_size *
                                  self.args.gradient_accumulation_steps *
                                  (torch.distributed.get_world_size()
                                   if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_device_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        self.total_flos = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(
                    model_path.split("-")[-1].split(os.path.sep)[0])
                self.total_flos = getattr(model.config, "total_flos", 0)

                epochs_trained = self.global_step // num_update_steps_per_epoch
                steps_trained_in_current_epoch = self.global_step % (
                    num_update_steps_per_epoch)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info(
                    "  Continuing training from %d non-embedding floating-point operations",
                    self.total_flos)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                self.total_flos = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = torch.tensor(0.0).to(self.args.device)
        logging_loss_scalar = 0.0
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero(
        )
        train_pbar = trange(epochs_trained,
                            int(np.ceil(num_train_epochs)),
                            desc="Epoch",
                            disable=disable_tqdm)
        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
            if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None
            epoch_pbar = tqdm(epoch_iterator,
                              desc="Iteration",
                              disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                tr_loss += self.training_step(model, inputs)
                self.total_flos += self.floating_point_ops(inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):

                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   self.args.max_grad_norm)
                    self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0
                            and self.global_step % self.args.logging_steps
                            == 0) or (self.global_step == 1
                                      and self.args.logging_first_step):
                        logs: Dict[str, float] = {}
                        tr_loss_scalar = tr_loss.item()
                        logs["loss"] = (tr_loss_scalar - logging_loss_scalar
                                        ) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers

                        logging_loss_scalar = tr_loss_scalar

                        self.log(logs)

                    if (self.args.evaluation_strategy
                            == EvaluationStrategy.STEPS
                            and self.global_step % self.args.eval_steps == 0):
                        metrics = self.evaluate()
                        self._report_to_hp_search(trial, epoch, metrics)

                epoch_pbar.update(1)
                if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
                    metrics = self.evaluate()
                    self._report_to_hp_search(trial, epoch, metrics)
                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    break
            epoch_pbar.close()
            train_pbar.update(1)
            metrics = self.evaluate()
            self._report_to_hp_search(trial, epoch, metrics)
            # In all cases (even distributed/parallel), self.model is always a reference
            # to the model we want to save.
            if hasattr(model, "module"):
                assert (
                    model.module is self.model
                ), f"Module {model.module} should be a reference to self.model"
            else:
                assert model is self.model, f"Model {model} should be a reference to self.model"
            # Save model checkpoint
            checkpoint_folder = f"checkpoints-epoch-{epoch + 1}"

            output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

            self.save_model(output_dir)

            if self.is_world_process_zero():
                self._rotate_checkpoints(use_mtime=True)

            torch.save(self.optimizer.state_dict(),
                       os.path.join(output_dir, "optimizer.pt"))
            torch.save(self.lr_scheduler.state_dict(),
                       os.path.join(output_dir, "scheduler.pt"))
            if self.args.tpu_metrics_debug or self.args.debug:
                logger.warning(
                    "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                    "configured. Check your training configuration if this is unexpected."
                )
            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                break

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
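
The epoch count derived from max_steps above is a ceiling division in disguise; a short check of that arithmetic:

import math

num_batches = 2500
grad_accum = 8
max_steps = 1000

num_update_steps_per_epoch = max(num_batches // grad_accum, 1)  # 312
num_train_epochs = max_steps // num_update_steps_per_epoch + int(
    max_steps % num_update_steps_per_epoch > 0)                 # 4
assert num_train_epochs == math.ceil(max_steps / num_update_steps_per_epoch)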
Example No. 12
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 file_path: str,
                 block_size: int,
                 overwrite_cache=False):
        assert os.path.isfile(file_path)
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)

        logger.info("Creating features from dataset file at %s", file_path)
        logger.info(f"block_size: {block_size}")
        block_size = block_size - tokenizer.num_special_tokens_to_add(
            pair=False)

        directory, filename = os.path.split(file_path)
        cached_features_file = os.path.join(
            directory,
            "cached_lm_{}_{}_{}".format(
                tokenizer.__class__.__name__,
                str(block_size),
                filename,
            ),
        )

        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):

            if os.path.exists(cached_features_file) and not overwrite_cache:
                start = time.time()
                with open(cached_features_file, "rb") as handle:
                    self.examples = pickle.load(handle)
                logger.info(
                    f"Loading features from cached file {cached_features_file} [took %.3f s]",
                    time.time() - start)

            else:
                logger.info(
                    f"Creating features from dataset file at {directory}")

                with open(file_path, encoding="utf-8") as f:
                    lines = [
                        line for line in f.read().splitlines()
                        if (len(line) > 0 and not line.isspace())
                    ]

                line_length = len(lines)
                i = 0
                delta = 10000
                self.examples = []
                logger.info(f"data set length: {len(lines)}, tokenizing...")
                while i + delta < line_length:
                    # if i > 80000:
                    #     break
                    logger.info(f"tokenizing line: {i} ~ {i+delta}..")
                    batch_encoding = tokenizer(lines[i:i + delta],
                                               add_special_tokens=True,
                                               padding="max_length",
                                               truncation=True,
                                               max_length=block_size)
                    if batch_encoding["input_ids"]:
                        logger.info(
                            f"first input ids: {len(batch_encoding['input_ids'][0])}"
                        )
                    # assert all(len(x) == block_size for x in batch_encoding["input_ids"]), "some sample length are invalid!!"
                    self.examples.extend(batch_encoding["input_ids"])
                    i += delta
                if i < line_length <= i + delta:
                    logger.info(f"tokenizing line: {i} ~ {line_length}..")
                    batch_encoding = tokenizer(lines[i:],
                                               add_special_tokens=True,
                                               padding="max_length",
                                               truncation=True,
                                               max_length=block_size)
                    if batch_encoding["input_ids"]:
                        logger.info(
                            f"first input ids: {len(batch_encoding['input_ids'][0])}"
                        )
                    # assert all(len(x) == block_size for x in batch_encoding["input_ids"]), "some sample length are invalid!!"
                    self.examples.extend(batch_encoding["input_ids"])

                # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
                # If your dataset is small, first you should look for a bigger one :-) and second you
                # can change this behavior by adding (model specific) padding.

                start = time.time()
                with open(cached_features_file, "wb") as handle:
                    pickle.dump(self.examples,
                                handle,
                                protocol=pickle.HIGHEST_PROTOCOL)
                logger.info(
                    "Saving features into cached file %s [took %.3f s]",
                    cached_features_file,
                    time.time() - start)
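
The while-loop above tokenizes the file in chunks of 10,000 lines (delta) so that very large corpora do not have to be encoded in one call; the same chunking pattern in a standalone sketch (plain Python, with a length stand-in instead of the tokenizer):

def chunked(items, size=10000):
    """Yield successive slices of `items` with at most `size` elements each."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


lines = [f"line {i}" for i in range(25000)]
encoded = []
for batch in chunked(lines):
    encoded.extend(len(line.split()) for line in batch)  # stand-in for tokenizer(batch)
assert len(encoded) == len(lines)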