class WinterSchool_FineTuning:
    def __init__(self,
                 tokenizer_path,
                 train_ratio=None,
                 batch_size=None,
                 epoch=None):
        self.epoch = epoch
        self.batch = batch_size
        self.train_ratio = train_ratio
        self.set_device()
        self.build_BERT(tokenizer_path)
        self.trained = False

    def set_device(self):
        import torch
        self.device = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        if not torch.cuda.is_available():
            print(
                "Warning! The runtime is not set to use a GPU. "
                "Please enable it under [Runtime]-[Change runtime type]."
            )
        else:
            print("GPU를 사용합니다. {}".format(torch.cuda.get_device_name(0)))

    def get_max_length(self, corpus, verbose=False) -> int:
        mxlen = 0
        for sent in corpus:
            if type(sent) is str:
                tokens = self.tokenizer.tokenize(sent)
                # +2 leaves room for the [CLS] and [SEP] tokens added later by encode_plus
                mxlen = max(mxlen, len(tokens) + 2)
        if verbose:
            print("max length is...", mxlen)
        return mxlen

    def encode(self, corpus, labels=None, _tqdm=True, verbose=False):
        from tqdm.notebook import tqdm
        import torch

        self.corpus = corpus

        input_ids = []
        attention_masks = []
        if labels is not None:
            assert len(corpus) == len(labels)
        mxlen = self.get_max_length(corpus, verbose)
        # Iterate with or without a progress bar; the encoding itself is identical.
        iterator = tqdm(corpus) if _tqdm else corpus
        for sent in iterator:
            encoded = self.tokenizer.encode_plus(
                sent,
                add_special_tokens=True,
                max_length=mxlen,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt')
            input_ids.append(encoded['input_ids'])
            attention_masks.append(encoded['attention_mask'])

        self.input_ids = torch.cat(input_ids, dim=0)
        self.attention_masks = torch.cat(attention_masks, dim=0)

        if labels is not None:
            self.labels = torch.tensor(labels)

    def get_corpus_specifications(self):
        from Korpora import Korpora
        for name, desc in Korpora.corpus_list().items():
            print("{:<40}  {:<}".format(name, desc))

    def build_corpus(self, corpus_name):
        from Korpora import Korpora
        return Korpora.load(corpus_name)

    def build_BERT(self, tokenizer_path):
        from transformers import BertTokenizer, BertForSequenceClassification
        self.bert_tokenizer_path = tokenizer_path
        self.tokenizer = BertTokenizer.from_pretrained(
            self.bert_tokenizer_path)
        # The rest of the class expects self.bert to exist; assume the
        # classification model is loaded from the same checkpoint as the tokenizer.
        self.bert = BertForSequenceClassification.from_pretrained(
            self.bert_tokenizer_path)

    def prepare(self, verbose=False):
        self.build_dataset(verbose)
        self.build_dataloader()
        self.build_optimizer()
        self.build_scheduler()

    def build_scheduler(self):
        from transformers import get_linear_schedule_with_warmup
        self.total_steps = len(self.train_dataloader) * self.epoch
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=0,  # Default value in run_glue.py
            num_training_steps=self.total_steps)

    def build_optimizer(self):
        from transformers import AdamW
        self.optimizer = AdamW(self.bert.parameters(), lr=2e-5, eps=1e-8)

    def build_dataset(self, verbose):
        from torch.utils.data import TensorDataset, random_split
        # encode() must run first so the input tensors exist.
        assert hasattr(self, 'input_ids') and hasattr(self, 'attention_masks')

        if getattr(self, 'labels', None) is not None:
            self.dataset = TensorDataset(self.input_ids, self.attention_masks,
                                         self.labels)
        else:
            self.dataset = TensorDataset(self.input_ids, self.attention_masks)

        self.train_size = int(self.train_ratio * len(self.dataset))
        self.val_size = len(self.dataset) - self.train_size

        self.train_dataset, self.val_dataset = random_split(
            self.dataset, [self.train_size, self.val_size])
        if verbose:
            print('{:>5,} training samples'.format(self.train_size))
            print('{:>5} validation samples'.format(self.val_size))

    def build_dataloader(self):
        from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
        assert self.train_dataset is not None and self.val_dataset is not None

        self.train_dataloader = DataLoader(
            self.train_dataset,
            sampler=RandomSampler(self.train_dataset),
            batch_size=self.batch,
        )
        self.validation_dataloader = DataLoader(self.val_dataset,
                                                sampler=SequentialSampler(
                                                    self.val_dataset),
                                                batch_size=self.batch)

    def flat_accuracy(self, preds, labels):
        import numpy as np
        pred_flat = np.argmax(preds, axis=1).flatten()
        labels_flat = labels.flatten()
        return np.sum(pred_flat == labels_flat) / len(labels_flat)

    def train(self, verbose=True):
        from tqdm.notebook import tqdm
        import random
        import torch
        import numpy as np

        seed_val = 42

        random.seed(seed_val)
        np.random.seed(seed_val)
        torch.manual_seed(seed_val)
        torch.cuda.manual_seed_all(seed_val)
        training_log = []
        desc_training_loss = None
        self.bert.train()
        self.bert.to(self.device)
        with tqdm(range(0, self.epoch),
                  leave=False,
                  bar_format=
                  "{percentage:2.2f}% {bar} {desc} | {elapsed}>{remaining}"
                  ) as t:
            for epoch_i in range(0, self.epoch):
                t.update()
                total_train_loss = 0
                train_accs = []

                for step, batch in enumerate(self.train_dataloader):
                    desc = "epoch: {:,}/{:,} | step: {:,}/{:,}".format(
                        epoch_i + 1, len(range(0, self.epoch)), step + 1,
                        len(self.train_dataloader))

                    if desc_training_loss is not None:
                        t.set_description_str(desc + " | " +
                                              desc_training_loss)
                    else:
                        t.set_description_str(desc)

                    b_input_ids, b_input_mask, b_labels = map(
                        lambda e: e.to(self.device), batch)

                    self.bert.zero_grad()

                    output = self.bert(b_input_ids,
                                       token_type_ids=None,
                                       attention_mask=b_input_mask,
                                       labels=b_labels)
                    loss = output[0]
                    logits = output[1]

                    total_train_loss += loss.item()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.bert.parameters(), 1.0)

                    logits = logits.detach().cpu().numpy()
                    label_ids = b_labels.to('cpu').numpy()
                    acc = self.flat_accuracy(logits, label_ids)
                    train_accs.append(acc)
                    self.optimizer.step()
                    self.scheduler.step()
                avg_train_acc = sum(train_accs) / len(train_accs)
                avg_train_loss = total_train_loss / \
                    len(self.train_dataloader)
                desc_training_loss = "mean training loss: {:.2f} / average accuracies:{}".format(
                    avg_train_loss, round(avg_train_acc, 2))
                training_log.append("{:<50}{}".format(desc,
                                                      desc_training_loss))

        if verbose:
            for log in training_log:
                print(log)
        self.trained = True

    def validate(self):
        import torch
        self.bert.eval()
        total_eval_accuracy = 0
        total_eval_loss = 0
        self.bert.to(self.device)
        for batch in self.validation_dataloader:
            b_input_ids, b_input_mask, b_labels = map(
                lambda e: e.to(self.device), batch)

            with torch.no_grad():
                output = self.bert(b_input_ids,
                                   token_type_ids=None,
                                   attention_mask=b_input_mask,
                                   labels=b_labels)
            loss = output[0]
            logits = output[1]

            total_eval_loss += loss.item()
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()
            total_eval_accuracy += self.flat_accuracy(logits, label_ids)
        avg_val_accuracy = total_eval_accuracy / \
            len(self.validation_dataloader)
        avg_val_loss = total_eval_loss / len(self.validation_dataloader)

        print("  Validation Loss: {0:.2f}".format(avg_val_loss))
        print("  Validation Accuracy: {0:.2f}".format(avg_val_accuracy))
Example #2
def train_f1_f2(args, model_f1, model_f2, train_dataset):
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.mini_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    args.num_train_epochs = 1
    t_total = len(train_dataloader
                  ) // args.gradient_accumulation_steps * args.num_train_epochs

    if args.warmup_proportion > 0:
        args.warmup_steps = int(t_total * args.warmup_proportion)

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in list(model_f1.named_parameters()) +
                list(model_f2.named_parameters())
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in list(model_f1.named_parameters()) +
                list(model_f2.named_parameters())
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        [model_f1,
         model_f2], optimizer = amp.initialize([model_f1, model_f2],
                                               optimizer,
                                               opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_f1 = torch.nn.DataParallel(model_f1)
        model_f2 = torch.nn.DataParallel(model_f2)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_f1 = torch.nn.parallel.DistributedDataParallel(
            model_f1,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

        model_f2 = torch.nn.parallel.DistributedDataParallel(
            model_f2,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    tr_loss, logging_loss = 0.0, 0.0
    model_f1.zero_grad()
    model_f2.zero_grad()

    set_seed(args)
    logger.info("***** train f1 f2 ******")
    logger.info("***** Num examples: {} ********".format(len(train_dataset)))

    for _ in range(1):
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iter(loss=X.XXX, lr=X.XXXXXXXX)",
                              disable=args.local_rank not in [-1, 0])

        for step, batch in enumerate(epoch_iterator):
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model_f1.train()
            model_f2.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": batch[3],
                "label_mask": batch[4]
            }

            outputs1 = model_f1(**inputs)
            loss1 = outputs1

            outputs2 = model_f2(**inputs)
            loss2 = outputs2

            w1 = model_f1.classifier.weight  # [num_labels, hidden_size]
            w2 = model_f2.classifier.weight.transpose(
                -1, -2)  # [hidden_size, num_labels]

            norm_term = torch.norm(torch.matmul(w1, w2))

            loss = loss1 + loss2 + args.alpha * norm_term

            if args.n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                epoch_iterator.set_description(
                    'Iter (loss=%5.3f) lr=%9.7f' %
                    (loss.item(), scheduler.get_lr()[0]))
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_f1.parameters(),
                                                   args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(model_f2.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model_f1.zero_grad()
                model_f2.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics

                    tb_writer.add_scalar("f1_f2_lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("f1_f2_loss",
                                         (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return model_f1, model_f2
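For reference, a self-contained toy sketch of the cross-classifier norm penalty computed above (sizes are arbitrary; this is an illustration, not the authors' code):

import torch
import torch.nn as nn

hidden_size, num_labels = 8, 3
f1_head = nn.Linear(hidden_size, num_labels)   # weight shape: [num_labels, hidden_size]
f2_head = nn.Linear(hidden_size, num_labels)

w1 = f1_head.weight                            # [num_labels, hidden_size]
w2 = f2_head.weight.transpose(-1, -2)          # [hidden_size, num_labels]

# Frobenius norm of W1 @ W2^T; adding alpha * norm_term to the combined loss
# penalizes the two heads for learning similar weight directions.
norm_term = torch.norm(torch.matmul(w1, w2))
print(norm_term.item())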
Example #3
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    # Added here for reproducibility
    set_seed(args)

    for e in train_iterator:
        if args.local_rank != -1:
            train_dataloader.sampler.set_epoch(e)
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in [
                    "xlm", "roberta", "distilbert", "camembert", "bart"
            ]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(
                        model.config, "lang2id"):
                    inputs.update({
                        "langs":
                        (torch.ones(batch[0].shape, dtype=torch.int64) *
                         args.lang_id).to(args.device)
                    })

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
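A small, standalone illustration of the get_linear_schedule_with_warmup behaviour relied on above: the learning-rate multiplier climbs linearly to 1.0 over the warmup steps and then decays linearly to 0 at num_training_steps (the tiny model and step counts are just for demonstration).

import torch
from transformers import get_linear_schedule_with_warmup

dummy = torch.nn.Linear(4, 2)
opt = torch.optim.AdamW(dummy.parameters(), lr=1.0)   # lr=1.0 so the printed lr equals the multiplier
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=4, num_training_steps=10)

for step in range(10):
    opt.step()
    sched.step()
    print(step, round(sched.get_last_lr()[0], 3))
# 0.25, 0.5, 0.75, 1.0 during warmup, then 0.833, 0.667, ... down to 0.0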
Example #4
class Replay:

    def __init__(self, device, **kwargs):
        self.lr = kwargs.get('lr', 3e-5)
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=1,
                                         max_length=kwargs.get('max_length'),
                                         device=device)
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=3)
        logger.info('Loaded {} as the model'.format(self.model.__class__.__name__))

        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=self.lr)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def save_model(self, model_path):
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint)

    def train(self, dataloader, n_epochs, log_freq):

        self.model.train()

        for epoch in range(n_epochs):
            all_losses, all_predictions, all_labels = [], [], []
            iter = 0

            for text, label, candidates in dataloader:
                replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(text, label,
                                                                                                         candidates)

                input_dict = self.model.encode_text(list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)
                targets = torch.tensor(ranking_label).float().unsqueeze(1).to(self.device)
                loss = self.loss_fn(output, targets)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                mini_batch_size = len(label)
                replay_freq = self.replay_every // mini_batch_size
                replay_steps = int(self.replay_every * self.replay_rate / mini_batch_size)

                if self.replay_rate != 0 and (iter + 1) % replay_freq == 0:
                    self.optimizer.zero_grad()
                    for _ in range(replay_steps):
                        ref_text, ref_label, ref_candidates = self.memory.read_batch(batch_size=mini_batch_size)
                        replicated_ref_text, replicated_ref_relations, ref_ranking_label = datasets.utils.replicate_rel_data(ref_text, ref_label, ref_candidates)
                        ref_input_dict = self.model.encode_text(list(zip(replicated_ref_text, replicated_ref_relations)))
                        ref_output = self.model(ref_input_dict)
                        ref_targets = torch.tensor(ref_ranking_label).float().unsqueeze(1).to(self.device)
                        ref_loss = self.loss_fn(ref_output, ref_targets)
                        ref_loss.backward()

                    params = [p for p in self.model.parameters() if p.requires_grad]
                    torch.nn.utils.clip_grad_norm_(params, 10)
                    self.optimizer.step()

                loss = loss.item()
                pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(true_labels.tolist())
                iter += 1
                self.memory.write_batch(text, label, candidates)

                if iter % log_freq == 0:
                    acc = models.utils.calculate_accuracy(all_predictions, all_labels)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(epoch + 1, np.mean(all_losses), acc))
                    all_losses, all_predictions, all_labels = [], [], []

    def evaluate(self, dataloader):
        all_losses, all_predictions, all_labels = [], [], []

        self.model.eval()

        for text, label, candidates in dataloader:
            replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(text, label,
                                                                                                     candidates)

            with torch.no_grad():
                input_dict = self.model.encode_text(list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)

            pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
            all_predictions.extend(pred.tolist())
            all_labels.extend(true_labels.tolist())

        acc = models.utils.calculate_accuracy(all_predictions, all_labels)

        return acc

    def training(self, train_datasets, **kwargs):
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 20)
        mini_batch_size = kwargs.get('mini_batch_size')
        train_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = data.DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=False,
                                           collate_fn=datasets.utils.rel_encode)
        self.train(dataloader=train_dataloader, n_epochs=n_epochs, log_freq=log_freq)

    def testing(self, test_dataset, **kwargs):
        mini_batch_size = kwargs.get('mini_batch_size')
        test_dataloader = data.DataLoader(test_dataset, batch_size=mini_batch_size, shuffle=False,
                                          collate_fn=datasets.utils.rel_encode)
        acc = self.evaluate(dataloader=test_dataloader)
        logger.info('Overall test metrics: Accuracy = {:.4f}'.format(acc))
        return acc
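The ReplayMemory used above comes from the authors' codebase and is not shown here; as a rough sketch of the interface the trainer assumes (write_batch gated by write_prob, read_batch sampling stored tuples), something along these lines would fit:

import random

class SimpleReplayMemory:
    """Hypothetical stand-in for the ReplayMemory interface used above."""

    def __init__(self, write_prob, tuple_size):
        self.write_prob = write_prob
        self.tuple_size = tuple_size
        self.buffer = []

    def write_batch(self, *columns):
        # columns arrive as (texts, labels, candidates); store row tuples
        assert len(columns) == self.tuple_size
        for row in zip(*columns):
            if random.random() < self.write_prob:
                self.buffer.append(row)

    def read_batch(self, batch_size):
        rows = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        # return column-wise, matching how the trainer unpacks read_batch()
        return tuple(list(col) for col in zip(*rows))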
Example #5
def train():
    writer = SummaryWriter(comment="Relation")
    modelDir = writer.log_dir.replace("runs", "models")
    epochs = 20
    device = "cuda"
    dataset = RelationDataset("albert-base-v2", device="cpu")
    dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn_padd)
    model = AlbertForRelation.from_pretrained(
        "albert-base-v2",
        num_rel_labels=len(relationTypes),
    )
    model.resize_token_embeddings(len(dataset.dataset.tokenizer))
    model.to(device)
    optim = AdamW(
        [
            {"params": model.albert.parameters(), "lr": 1e-4},
            {
                "params": model.classifier.parameters(),
                "lr": 1e-3,
            },
        ]
    )
    scheduler = get_linear_schedule_with_warmup(optim, 100, epochs * 10000 / 32)

    iTot = 0
    for epoch in range(epochs):
        i = 0
        lossesTrain = []
        lossesVal = []
        for (
            input_ids,
            token_type_ids,
            attention_mask,
            rel_label,
            e1_index,
            e2_index,
        ) in dataloader:
            if i % 5 != 0:
                model.train()
                loss, acc = model(
                    input_ids.to(device),
                    token_type_ids.to(device),
                    attention_mask.to(device),
                    rel_label.to(device),
                    e1_index.to(device),
                    e2_index.to(device),
                )
                loss.backward()
                optim.step()
                scheduler.step()
                optim.zero_grad()
                lossesTrain.append(loss.item())
                writer.add_scalar("lossRel/Train", lossesTrain[-1], iTot)
                writer.add_scalar("accRel/Train", acc.item(), iTot)
            else:
                with torch.no_grad():
                    model.eval()
                    loss, acc = model(
                        input_ids.to(device),
                        token_type_ids.to(device),
                        attention_mask.to(device),
                        rel_label.to(device),
                        e1_index.to(device),
                        e2_index.to(device),
                    )
                    lossesVal.append(loss.item())
                    writer.add_scalar("accRel/Eval", acc.item(), iTot)
                    writer.add_scalar("lossRel/Eval", lossesVal[-1], iTot)
            if iTot % 20 == 0:
                for (i2, lr) in enumerate(scheduler.get_lr()):
                    writer.add_scalar("lr/" + str(i2), lr, iTot)
            print(epoch, i)
            i += 1
            iTot += 1
        model.save_pretrained(modelDir + "/" + str(epoch))
        dataset.dataset.tokenizer.save_pretrained(modelDir + "/" + str(epoch))
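Context for the resize_token_embeddings call above: when extra tokens (e.g. entity markers for relation extraction) are added to the tokenizer, the model's embedding matrix must be resized to match. A minimal illustration with stock Hugging Face classes; the marker strings are an assumption, not taken from this example's dataset code.

from transformers import AlbertTokenizer, AlbertForSequenceClassification

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForSequenceClassification.from_pretrained("albert-base-v2")

# hypothetical entity-marker tokens
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
model.resize_token_embeddings(len(tokenizer))  # new embedding rows are randomly initialized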
Example #6
def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [{
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        args.weight_decay
    }, {
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"]: batch[2] if args.model_type in [
                    "bert", "xlnet"
                ] else None  # XLM and RoBERTa don"t use segment_ids

            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            print("loss", loss)
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results, _ = evaluate(args,
                                              model,
                                              tokenizer,
                                              labels,
                                              pad_token_label_id,
                                              mode="dev")
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, "module"
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
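The gradient-accumulation pattern used throughout these examples, reduced to its core (toy model and random data; illustration only): scale the loss by the accumulation count, call backward() on every micro-batch, and only clip, step and zero the gradients every accumulation_steps micro-batches.

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
accumulation_steps = 4

data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(16)]

optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    loss = torch.nn.functional.cross_entropy(model(x), y)
    (loss / accumulation_steps).backward()      # gradients accumulate in .grad
    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()                        # one real update per 4 micro-batches
        optimizer.zero_grad()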
Example #7
    def train(self):

        train_sampler = RandomSampler(self.train_dataset)

        train_dataloader = DataLoader(self.train_dataset,
                                      sampler=train_sampler,
                                      batch_size=self.args.train_batch_size)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // (len(
                train_dataloader) // self.args.gradient_accumulation_steps) + 1
        else:
            t_total = len(
                train_dataloader
            ) // self.args.gradient_accumulation_steps * self.args.num_train_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            self.args.weight_decay
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.args.learning_rate,
                          eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=t_total)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(self.train_dataset))
        logger.info("  Num Epochs = %d", self.args.num_train_epochs)
        logger.info("  Total train batch size = %d",
                    self.args.train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)
        logger.info("  Logging steps = %d", self.args.logging_steps)
        logger.info("  Save steps = %d", self.args.save_steps)

        witer = SummaryWriter(logdir=self.args.model_dir,
                              comment="classification")

        global_step = 0
        tr_loss = 0.0
        best_mean_precision = 0.0
        self.model.zero_grad()

        train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")
        set_seed(self.args)

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU

                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2],
                    'e1_mask': batch[3],
                    'e2_mask': batch[4],
                    'o_label': batch[5],
                    'm_label': batch[6]
                }
                outputs = self.model(**inputs)
                loss = outputs[0]
                # eval_results = self.evaluate('dev')
                # print(1)

                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()
                    global_step += 1

                    if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
                        eval_results = self.evaluate('dev')

                        witer.add_scalar("Test/loss",
                                         eval_results.get("loss", 0),
                                         global_step)
                        # witer.add_scalar("Test/total", eval_results.get('total', {}).get("precision", 0), global_step)
                        # witer.add_scalar("Test/total", eval_results.get('total', {}).get("recall", 0), global_step)
                        # witer.add_scalar("Test/total", eval_results.get('total', {}).get("f1", 0), global_step)

                        witer.add_scalar(
                            "Test/mean",
                            eval_results.get('mean',
                                             {}).get("mean_precision", 0),
                            global_step)
                        witer.add_scalar(
                            "Test/mean",
                            eval_results.get('mean', {}).get("mean_recall", 0),
                            global_step)
                        witer.add_scalar(
                            "Test/mean",
                            eval_results.get('mean',
                                             {}).get("mean_f1-score", 0),
                            global_step)

                        witer.add_scalar(
                            "Test/sum",
                            eval_results.get('sum',
                                             {}).get("sum_precision", 0),
                            global_step)
                        witer.add_scalar(
                            "Test/sum",
                            eval_results.get('sum', {}).get("sum_recall", 0),
                            global_step)
                        witer.add_scalar(
                            "Test/sum",
                            eval_results.get('sum', {}).get("sum_f1-score", 0),
                            global_step)
                        # labels = ["0", "10001", "10002"]
                        # for k in labels:
                        #     witer.add_scalar("Test/{}".format(k), eval_results.get(k, {}).get("precision", 0), global_step)
                        #     witer.add_scalar("Test/{}".format(k), eval_results.get(k, {}).get("recall", 0), global_step)
                        #     witer.add_scalar("Test/{}".format(k), eval_results.get(k, {}).get("f1", 0), global_step)

                        # if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                        #     if eval_results.get("total", {}).get("f1", 0) > best_mean_precision:
                        #         best_mean_precision = eval_results.get("total", {}).get("f1", 0)
                        #         self.save_model()

                        if eval_results.get("mean", {}).get(
                                "mean_f1-score", 0) >= best_mean_precision:
                            best_mean_precision = eval_results.get(
                                "mean", {}).get("mean_f1-score", 0)
                            self.save_model()

                if 0 < self.args.max_steps < global_step:
                    epoch_iterator.close()
                    break

                witer.add_scalar("Train/loss", tr_loss / global_step,
                                 global_step)
                lr = scheduler.get_lr()[-1]
                witer.add_scalar("Train/lr", lr, global_step)

            eval_results = self.evaluate('dev')
            witer.add_scalar("Test/loss", eval_results.get("loss", 0),
                             global_step)
            # witer.add_scalar("Test/total", eval_results.get('total', {}).get("precision", 0), global_step)
            # witer.add_scalar("Test/total", eval_results.get('total', {}).get("recall", 0), global_step)
            # witer.add_scalar("Test/total", eval_results.get('total', {}).get("f1", 0), global_step)

            witer.add_scalar(
                "Test/mean",
                eval_results.get('mean', {}).get("mean_precision", 0),
                global_step)
            witer.add_scalar(
                "Test/mean",
                eval_results.get('mean', {}).get("mean_recall", 0),
                global_step)
            witer.add_scalar(
                "Test/mean",
                eval_results.get('mean', {}).get("mean_f1-score", 0),
                global_step)

            witer.add_scalar(
                "Test/sum",
                eval_results.get('sum', {}).get("sum_precision", 0),
                global_step)
            witer.add_scalar("Test/sum",
                             eval_results.get('sum', {}).get("sum_recall", 0),
                             global_step)
            witer.add_scalar(
                "Test/sum",
                eval_results.get('sum', {}).get("sum_f1-score", 0),
                global_step)

            if eval_results.get("mean", {}).get("mean_f1-score",
                                                0) >= best_mean_precision:
                best_mean_precision = eval_results.get("mean", {}).get(
                    "mean_f1-score", 0)
                self.save_model()

            if 0 < self.args.max_steps < global_step:
                train_iterator.close()
                break

        return global_step, tr_loss / global_step
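One note on the TensorBoard logging above: add_scalar() keeps one curve per tag, so writing precision, recall and F1 all under "Test/mean" at the same global_step interleaves them on a single plot. Giving each metric its own tag keeps the curves separate (dummy values below, for illustration only):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(comment="classification")
eval_results = {"mean_precision": 0.81, "mean_recall": 0.78, "mean_f1-score": 0.79}  # dummy values
for name, value in eval_results.items():
    writer.add_scalar("Test/mean/" + name, value, 1)
writer.close()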
Example #8
def train(weight,
          model_num,
          model,
          train_dataloader,
          validation_dataloader,
          filepath='',
          lr=2e-5,
          EPOCHS=10,
          BATCH_SIZE=1):
    total_t0 = time.time()
    training_stats = []
    model = model.to(DEVICE)
    optimizer = AdamW(model.parameters(), lr=lr)

    weight = weight.to(DEVICE)
    loss_func = nn.NLLLoss(weight)
    loss_real = nn.NLLLoss()
    softmax = nn.LogSoftmax(dim=1)

    for epoch_num in range(EPOCHS):
        t0 = time.time()
        model.train()
        total_train_loss = 0
        for step_num, batch_data in enumerate(train_dataloader):
            input_ids, attention_masks, anger, fear, joy, sadness, vec, intensity = tuple(
                t for t in batch_data)

            ##ordinal
            o1 = torch.tensor((intensity.numpy() > 0).astype(int))
            o2 = torch.tensor((intensity.numpy() > 1).astype(int))
            o3 = torch.tensor((intensity.numpy() > 2).astype(int))

            if model_num == 1:
                o = o1
            if model_num == 2:
                o = o2
            if model_num == 3:
                o = o3
            ###

            input_ids = input_ids.to(DEVICE)
            attention_masks = attention_masks.to(DEVICE)
            anger = anger.to(DEVICE)
            fear = fear.to(DEVICE)
            joy = joy.to(DEVICE)
            sadness = sadness.to(DEVICE)
            intensity = intensity.to(DEVICE)
            vec = vec.to(DEVICE)
            o = o.to(DEVICE)

            model.zero_grad()

            probas = model(input_ids, attention_masks)
            loss = loss_func(probas, o)

            total_train_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # scheduler.step()
            print('Epoch: ', epoch_num + 1)
            print("\r" + "{0}/{1} loss: {2} ".format(
                step_num,
                len(train_dataloader) / BATCH_SIZE, total_train_loss /
                (step_num + 1)))
        avg_train_loss = total_train_loss / len(train_dataloader)

        #model_save_name = filepath + '_epoch_' + str(epoch_num) + '_lr_' + str(lr) + '_' + str(model_num) + '.pt'
        #torch.save(model.state_dict(), model_save_name)
        training_time = format_time(time.time() - t0)
        model.eval()

        total_pearson = 0
        total_kappa = 0
        total_eval_loss_r = 0
        total_eval_loss_w = 0

        for batch_data in validation_dataloader:
            input_ids, attention_masks, anger, fear, joy, sadness, vec, intensity = tuple(
                t for t in batch_data)
            ##ordinal
            o1 = torch.tensor((intensity.numpy() > 0).astype(int))
            o2 = torch.tensor((intensity.numpy() > 1).astype(int))
            o3 = torch.tensor((intensity.numpy() > 2).astype(int))

            if model_num == 1:
                o = o1
            elif model_num == 2:
                o = o2
            else:
                o = o3
            ###
            input_ids = input_ids.to(DEVICE)
            attention_masks = attention_masks.to(DEVICE)
            anger = anger.to(DEVICE)
            fear = fear.to(DEVICE)
            joy = joy.to(DEVICE)
            sadness = sadness.to(DEVICE)
            intensity = intensity.to(DEVICE)
            vec = vec.to(DEVICE)
            o = o.to(DEVICE)

            with torch.no_grad():
                probas = model(input_ids, attention_masks)
                output = torch.max(probas, 1)[1]

                lossr = loss_real(probas, o)
                lossw = loss_func(probas, o)
                # Accumulate the validation loss.
                total_eval_loss_r += lossr.item()
                total_eval_loss_w += lossw.item()

                output = output.detach().cpu()
                o = o.to('cpu')

                # Calculate the accuracy for this batch of test sentences, and
                # accumulate it over all batches.

                pear, _, kappa, _ = evaluate_PerEmotion(o, output)

                total_pearson += pear
                total_kappa += kappa

        # Report the final accuracy for this validation run.
        avg_pearson = total_pearson / len(validation_dataloader)
        avg_kappa = total_kappa / len(validation_dataloader)

        # Calculate the average loss over all of the batches.
        avg_val_loss_r = total_eval_loss_r / len(validation_dataloader)
        avg_val_loss_w = total_eval_loss_w / len(validation_dataloader)

        val_time = format_time(time.time() - t0)

        # Record all statistics from this epoch.
        training_stats.append({
            'epoch': epoch_num + 1,
            'Training Loss on 1 ordinal': avg_train_loss,
            'Valid. Loss on 1 ordinal, real': avg_val_loss_r,
            'Valid. Loss on 1 ordinal, weighted': avg_val_loss_w,
            'Pearson on 1 ordinal': avg_pearson,
            'Kappa on 1 ordinal': avg_kappa,
            'Learning Rate': lr,
            'Training Time': training_time,
            'Validation Time': val_time
        })

        print(training_stats)

    return training_stats, model
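# A minimal usage sketch (not part of the original example), assuming pandas is
# available and that `training_stats` is the list of per-epoch dicts returned above.
import pandas as pd

def stats_to_frame(training_stats):
    # Each entry holds one epoch's metrics, so the list converts directly to a table
    # that can be inspected or plotted.
    return pd.DataFrame(training_stats).set_index('epoch')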
Example #9
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
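    # Schedule note: get_linear_schedule_with_warmup ramps the learning rate linearly
    # from 0 to args.learning_rate over the first args.warmup_steps optimizer steps,
    # then decays it linearly back to 0 at step t_total, so scheduler.step() must be
    # called once per optimizer step (as done in the loop below).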
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0],
                              position=0,
                              leave=True)
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3]
            }
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = batch[2] if args.model_type in [
                    'bert', 'xlnet'
                ] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
            if args.use_length:
                if args.model_type == 'roberta':
                    inputs['lengths'] = (inputs['attention_mask'] -
                                         inputs['token_type_ids']).sum(
                                             dim=1, keepdim=True).float()
                if args.model_type == 'xlnet':
                    mask = inputs['attention_mask'] - inputs['token_type_ids']
                    mask[mask < 0] = 0
                    inputs['lengths'] = mask.sum(dim=1, keepdim=True).float()
            if args.use_matchings:
                inputs['matchings'] = batch[4]
            if args.join_embeddings:
                if args.model_type == 'roberta':
                    inputs['embeddings_mask'] = inputs[
                        'attention_mask'] - inputs['token_type_ids']
                if args.model_type == 'xlnet':
                    mask = inputs['attention_mask'] - inputs['token_type_ids']
                    mask[mask < 0] = 0
                    inputs['embeddings_mask'] = mask
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
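            # Gradient accumulation (explanatory note): scaling the loss here makes
            # the gradients summed over N micro-batches equal to those of one
            # averaged batch. E.g. a per-GPU batch of 8 with accumulation 4 behaves
            # like an effective batch of 32, with optimizer.step() every 4th batch.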

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = 'eval_{}'.format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs['learning_rate'] = learning_rate_scalar
                    logs['loss'] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{'step': global_step}}))

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
    def run(self):

        torch.cuda.empty_cache()

        # load raw data and run preprocessing
        self.generate_data()

        # create the directory where models will be saved
        self.generate_model_directory()

        if self.tokenizer_type != '':
            # generate corpus by Okt konlpy
            # self.generate_custom_morphs(self.list_memo)

            # generate tokenizer model
            self.generate_custom_vocab()

        tokenizer = None
        if self.tokenizer_type == '':
            # base tokenizer
            tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased",
                                                            lowercase=True,
                                                            strip_accents=False,
                                                            local_files_only=False)
        else:
            # word piece tokenizer
            tokenizer = DistilBertTokenizerFast.from_pretrained(self.vocab_root_dir + self.vocab_dir,
                                                                strip_accents=False,
                                                                lowercase=True)

        self.setPrint('Loaded custom vocab size: {}'.format(tokenizer.vocab_size))
        # tokenizer Loading check
        # tokenized_input_for_pytorch = tokenizer_for_load("i am very happy now", return_tensors="pt")
        # encoded_text = tokenizer("전화 통화가 정상적으로 안됨", return_tensors="pt")
        # self.setPrint("Tokens Text List: {}".format(
        #     [tokenizer.convert_ids_to_tokens(s) for s in encoded_text['input_ids'].tolist()[0]]))
        # self.setPrint("Tokens IDX  List: {}".format(encoded_text['input_ids'].tolist()[0]))
        # self.setPrint("Tokens Mask List: {}".format(encoded_text['attention_mask'].tolist()[0]))

        # transformed train data
        encoded_data_train = tokenizer.batch_encode_plus(
            self.Train_Data_X,
            add_special_tokens=True,
            return_attention_mask=True,
            # padding='longest',
            padding=True,
            max_length=256,
            return_tensors='pt',
            truncation=True
        )
        # transformed validation data
        encoded_data_val = tokenizer.batch_encode_plus(
            self.Test_Data_X,
            add_special_tokens=True,
            return_attention_mask=True,
            # padding='longest',
            padding=True,
            max_length=256,
            return_tensors='pt',
            truncation=True
        )

        input_ids_train = encoded_data_train['input_ids']
        attention_masks_train = encoded_data_train['attention_mask']
        labels_train = torch.tensor(self.Train_Data_Y)

        input_ids_test = encoded_data_val['input_ids']
        attention_masks_test = encoded_data_val['attention_mask']
        labels_test = torch.tensor(self.Test_Data_Y)

        dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
        dataset_test = TensorDataset(input_ids_test, attention_masks_test, labels_test)

        # local_files_only = True
        self.model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased",
                                                                         num_labels=len(self.label_index),
                                                                         output_attentions=False,
                                                                         output_hidden_states=False,
                                                                         local_files_only=False).to(self.device)

        # dataLoader
        dataloader_train = DataLoader(dataset_train,
                                      sampler=RandomSampler(dataset_train),
                                      batch_size=self.batch_size,
                                      drop_last=True)

        dataloader_test = DataLoader(dataset_test,
                                     sampler=RandomSampler(dataset_test),
                                     batch_size=self.batch_size)

        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, eps=1e-8)
        scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                                    num_warmup_steps=0,
                                                    num_training_steps=len(dataloader_train) * self.epoch)

        # scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer=optimizer,
        #                                                                num_warmup_steps=0,
        #                                                                num_training_steps=len(dataloader_train) * self.epoch)
        # for loss f1 graph
        total_train_loss = np.array([0.0000] * self.epoch)
        total_val_loss = np.array([0.0000] * self.epoch)
        total_score = np.array([0.0000] * self.epoch)

        # Training start
        for epoch in range(1, self.epoch + 1):
            self.setPrint('Start of Epoch {}'.format(epoch))
            self.model.train()
            loss_train_total = 0

            for idx, batch in enumerate(dataloader_train):
                self.model.zero_grad()
                batch = tuple(b.to(self.device) for b in batch)
                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'labels': batch[2],
                          }
                outputs = self.model(**inputs)
                loss = outputs[0]
                loss_train_total += loss.item()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                if idx % 100 == 0:
                    self.setPrint('[{}]Epoch {}/{} training_loss : {:.4f}'.format(epoch, idx, len(dataloader_train),
                                                                                  loss.item()))
                # gpu memory reset
                batch = None
                torch.cuda.empty_cache()

            # model save
            torch.save(self.model.state_dict(),
                       self.model_root_dir + self.model_dir + 'BERT_dict_epoch_{}.model'.format(epoch))
            self.setPrint('Save fine_tuned_BERT_epoch_{}.model'.format(epoch))
            self.setPrint('\nEnd of Epoch {}'.format(epoch))

            loss_train_avg = loss_train_total / len(dataloader_train)
            self.setPrint('[{}] Epoch Training loss: {:.4f}'.format(epoch, loss_train_avg))
            total_train_loss[epoch - 1] = round(loss_train_avg, 4)

            val_loss, predictions, true_vals = self.evaluate(dataloader_test)
            val_f1 = self.f1_score_func(predictions, true_vals)

            total_score[epoch - 1] = round(val_f1, 4)
            total_val_loss[epoch - 1] = round(val_loss, 4)

            self.setPrint('[{}] Validation loss: {:.4f}'.format(epoch, val_loss))
            self.setPrint('[{}] F1 Score : {:.4f}'.format(epoch, val_f1))

        # generate graph
        self.generate_graph(total_train_loss, total_val_loss, total_score)
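    # The evaluate() and f1_score_func() helpers called above are not shown in this
    # example. A minimal sketch of f1_score_func, assuming predictions are raw logits
    # of shape (n_samples, n_labels) and scikit-learn is available:
    def f1_score_func(self, preds, labels):
        import numpy as np
        from sklearn.metrics import f1_score
        preds_flat = np.argmax(preds, axis=1).flatten()  # logits -> predicted class ids
        labels_flat = labels.flatten()
        return f1_score(labels_flat, preds_flat, average='weighted')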
Example #11
def finetune(model, tokenizer, dataset, models_folder="trained_models", batch_size=1, epochs=5, learning_rate=0.0001, warmup_steps=5000, max_seq_len=400):
    poem_loader = DataLoader(PoemDataset(dataset), batch_size=1, shuffle=True)

    device = set_device()
    model = model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=-1)
    proc_seq_count = 0
    sum_loss = 0.0
    batch_count = 0

    tmp_poems_tens = None
    if not os.path.exists(models_folder):
        os.mkdir(models_folder)

    for epoch in range(epochs):
        print(f"EPOCH {epoch} started" + '=' * 30)

        for idx,poem in enumerate(poem_loader):

            #"Fit as many poem sequences into max_seq_len sequence as possible" logic start ####
            poem_tens = torch.tensor(tokenizer.encode(poem[0])).unsqueeze(0).to(device)
            #Skip sample from dataset if it is longer than max_seq_len
            if poem_tens.size()[1] > max_seq_len:
                continue

            #First poem in seq
            if not torch.is_tensor(tmp_poems_tens):
                tmp_poems_tens = poem_tens
                continue
            else:
                #The next poem does not fit in so we process the sequence and leave the last poem
                #as the start for next sequence
                if tmp_poems_tens.size()[1] + poem_tens.size()[1] > max_seq_len:
                    work_poems_tens = tmp_poems_tens
                    tmp_poems_tens = poem_tens
                else:
                    #Add the poem to sequence, continue and try to add more
                    tmp_poems_tens = torch.cat([tmp_poems_tens, poem_tens[:,1:]], dim=1)
                    continue
            #Sequence ready, pass through model

            outputs = model(work_poems_tens, labels=work_poems_tens)
            loss, logits = outputs[:2]
            loss.backward()  # backpropagate to accumulate gradients
            sum_loss = sum_loss + loss.detach().data

            proc_seq_count = proc_seq_count + 1
            if proc_seq_count == batch_size:
                proc_seq_count = 0
                batch_count += 1
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                model.zero_grad()

            if batch_count == 100:
                print(f"sum loss {sum_loss}")
                batch_count = 0
                sum_loss = 0.0

        # Store the model after each epoch to compare the performance of them
        torch.save(model.state_dict(), os.path.join(models_folder, f"gpt2_poet_epoch_{epoch}.pt"))
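# A minimal sampling sketch (not part of the original example), assuming `model` is a
# Hugging Face GPT-2 LM-head model and `tokenizer` its matching tokenizer, i.e. the
# same objects passed to finetune() above; torch is assumed to be imported as in the
# surrounding code.
def generate_poem(model, tokenizer, prompt, max_length=100):
    device = next(model.parameters()).device
    model.eval()
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    with torch.no_grad():
        # Nucleus/top-k sampling keeps the output varied instead of greedy repetition.
        output_ids = model.generate(input_ids, max_length=max_length,
                                    do_sample=True, top_p=0.9, top_k=50,
                                    pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)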
Example #12
def main():
    parser = argparse.ArgumentParser()
    model_name = './pretrained/hfl_chinese_roberta_wwm_ext'
    # model_name ='./pretrained/hfl_chinese_wwm_ext'
    ## Required parameters
    parser.add_argument("--data_dir",
                        default='.',
                        type=str,
                        # required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_config_file",
                        # default='./pretrained/hfl_chinese_wwm_ext',
                        # default='bert-base-chinese',
                        # default='hfl/rbt3',
                        default=model_name,
                        type=str,
                        # required=True,
                        help="The config json file corresponding to the pre-trained BERT model. \n"
                             "This specifies the model architecture.")
    parser.add_argument("--task_name",
                        default='c3',
                        type=str,
                        # required=True,
                        help="The name of the task to train.")
    # parser.add_argument("--vocab_file",
    #                     default=None,
    #                     type=str,
    #                     required=True,
    #                     help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument("--output_dir",
                        default=model_name.split('/')[-1]+'_'+t,
                        type=str,
                        # required=True,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--init_checkpoint",
                        default=None,
                        type=str,
                        help="Initial checkpoint (usually from a pre-trained BERT model).")
    parser.add_argument("--do_lower_case",
                        default=False,
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--max_seq_length",
                        default=512,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=True,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=12,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=12,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=10,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--save_checkpoints_steps",
                        default=1000,
                        type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=2,
                        help="Number of updates steps to accumualte before performing a backward/update pass.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")

    args = parser.parse_args()

    processors = {
        "c3": c3Processor,
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logging.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # bert_config = BertConfig.from_json_file(args.bert_config_file)

    if args.max_seq_length > 512:
        raise ValueError(
            "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
                args.max_seq_length, 512))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        os.makedirs(args.output_dir + args.output_dir[-1], exist_ok=True)
        # if args.do_train:
        #     raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    else:
        os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()

    # tokenizer = tokenization.FullTokenizer(
    #     vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)

    tokenizer = BertTokenizer.from_pretrained(args.bert_config_file)
    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(
                train_examples) / n_class / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    model = BertForMultipleChoice.from_pretrained(args.bert_config_file)
    # model = BSC.from_pretrained(args.bert_config_file,num_rel=1 if n_class > 1 else len(label_list))
    # if args.init_checkpoint is not None:
    #     model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
    model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)



    global_step = 0

    if args.do_eval:
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer)

        input_ids = []
        input_mask = []
        segment_ids = []
        label_id = []

        for f in eval_features:
            input_ids.append([])
            input_mask.append([])
            segment_ids.append([])
            for i in range(n_class):
                input_ids[-1].append(f[i].input_ids)
                input_mask[-1].append(f[i].input_mask)
                segment_ids[-1].append(f[i].segment_ids)
            label_id.append(f[0].label_id)

        all_input_ids = torch.tensor(input_ids, dtype=torch.long)
        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
        all_segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        all_label_ids = torch.tensor(label_id, dtype=torch.long)
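        # Shape note: the id/mask/segment tensors above are
        # (num_examples, n_class, max_seq_length) and the labels are (num_examples,),
        # i.e. every example carries one encoded sequence per answer choice plus a
        # single label index, which is the layout BertForMultipleChoice expects.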

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            eval_sampler = SequentialSampler(eval_data)
        else:
            eval_sampler = DistributedSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    if args.do_train:
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        #
        # optimizer = BERTAdam(optimizer_parameters,
        #                      lr=args.learning_rate,
        #                      warmup=args.warmup_proportion,
        #                      t_total=num_train_steps)

        # no_decay = ['bias', 'LayerNorm.weight']
        # optimizer_grouped_parameters = [
        #     {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        #      'weight_decay': args.weight_decay},
        #     {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        # ]
        optimizer = AdamW(optimizer_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, int(args.warmup_proportion * num_train_steps),
                                                    num_train_steps)
        best_accuracy = 0

        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)
        logging.info("***** Running training *****")
        logging.info("  Num examples = %d", len(train_examples))
        logging.info("  Batch size = %d", args.train_batch_size)
        logging.info("  Num steps = %d", num_train_steps)

        input_ids = []
        input_mask = []
        segment_ids = []
        label_id = []
        for f in train_features:
            input_ids.append([])
            input_mask.append([])
            segment_ids.append([])
            for i in range(n_class):
                input_ids[-1].append(f[i].input_ids)
                input_mask[-1].append(f[i].input_mask)
                segment_ids[-1].append(f[i].segment_ids)
            label_id.append(f[0].label_id)

        all_input_ids = torch.tensor(input_ids, dtype=torch.long)
        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
        all_segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        all_label_ids = torch.tensor(label_id, dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            model.train()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                # print(input_ids.size(), input_mask.size(), segment_ids.size(), label_ids.size())
                # loss, _ = model(input_ids=input_ids.view(-1,args.max_seq_length), token_type_ids=segment_ids.view(-1,args.max_seq_length),
                #                               attention_mask=input_mask.view(-1,args.max_seq_length), labels=label_ids.view(-1))
                # loss, _ = model(input_ids, segment_ids, input_mask, label_ids, n_class)
                loss, _ = model(input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask, labels=label_ids)

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()  # We have accumulated enough gradients
                    scheduler.step()
                    optimizer.zero_grad()
                    model.zero_grad()
                    global_step += 1

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            logits_all = []
            for st, batch in enumerate(tqdm(eval_dataloader, desc="Iteration")):
                input_ids, input_mask, segment_ids, label_ids = batch
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    # tmp_eval_loss, logits = model(input_ids=input_ids.view(-1,args.max_seq_length), token_type_ids=segment_ids.view(-1,args.max_seq_length),
                    #                           attention_mask=input_mask.view(-1,args.max_seq_length), labels=label_ids.view(-1))
                    # tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, n_class)
                    tmp_eval_loss, logits = model(input_ids=input_ids,
                                                  token_type_ids=segment_ids,
                                                  attention_mask=input_mask,
                                                  labels=label_ids)

                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                for i in range(len(logits)):
                    logits_all += [logits[i]]

                tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

                eval_loss += tmp_eval_loss.mean().item()
                eval_accuracy += tmp_eval_accuracy

                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = eval_accuracy / nb_eval_examples

            if args.do_train:
                result = {'eval_loss': eval_loss,
                          'eval_accuracy': eval_accuracy,
                          'global_step': global_step,
                          'loss': tr_loss / nb_tr_steps}
            else:
                result = {'eval_loss': eval_loss,
                          'eval_accuracy': eval_accuracy}

            logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logging.info("  %s = %s", key, str(result[key]))

            if eval_accuracy >= best_accuracy:
                torch.save(model.state_dict(), os.path.join(args.output_dir, "model_best.pt"))
                best_accuracy = eval_accuracy

        model.load_state_dict(torch.load(os.path.join(args.output_dir, "model_best.pt")))
        torch.save(model.state_dict(), os.path.join(args.output_dir, "model.pt"))

    model.load_state_dict(torch.load(os.path.join(args.output_dir, "model.pt")))

    if args.do_eval:
        logging.info("***** Running evaluation *****")
        logging.info("  Num examples = %d", len(eval_examples))
        logging.info("  Batch size = %d", args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        logits_all = []
        for st, batch in enumerate(tqdm(eval_dataloader, desc="Iteration")):
            input_ids, input_mask, segment_ids, label_ids = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                # tmp_eval_loss, logits = model(input_ids=input_ids.view(-1,args.max_seq_length), token_type_ids=segment_ids.view(-1,args.max_seq_length),
                #                               attention_mask=input_mask.view(-1,args.max_seq_length), labels=label_ids.view(-1))
                # tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, n_class)
                tmp_eval_loss, logits = model(input_ids=input_ids,
                                              token_type_ids=segment_ids,
                                              attention_mask=input_mask,
                                              labels=label_ids)
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            for i in range(len(logits)):
                logits_all += [logits[i]]

            tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        if args.do_train:
            result = {'eval_loss': eval_loss,
                      'eval_accuracy': eval_accuracy,
                      'global_step': global_step,
                      'loss': tr_loss / nb_tr_steps}
        else:
            result = {'eval_loss': eval_loss,
                      'eval_accuracy': eval_accuracy}

        output_eval_file = os.path.join(args.output_dir, "eval_results_dev.txt")
        with open(output_eval_file, "w") as writer:
            logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
        output_eval_file = os.path.join(args.output_dir, "logits_dev.txt")
        with open(output_eval_file, "w") as f:
            for i in range(len(logits_all)):
                for j in range(len(logits_all[i])):
                    f.write(str(logits_all[i][j]))
                    if j == len(logits_all[i]) - 1:
                        f.write("\n")
                    else:
                        f.write(" ")

        eval_examples = processor.get_test_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer)

        logging.info("***** Running evaluation *****")
        logging.info("  Num examples = %d", len(eval_examples))
        logging.info("  Batch size = %d", args.eval_batch_size)

        input_ids = []
        input_mask = []
        segment_ids = []
        label_id = []

        for f in eval_features:
            input_ids.append([])
            input_mask.append([])
            segment_ids.append([])
            for i in range(n_class):
                input_ids[-1].append(f[i].input_ids)
                input_mask[-1].append(f[i].input_mask)
                segment_ids[-1].append(f[i].segment_ids)
            label_id.append(f[0].label_id)

        all_input_ids = torch.tensor(input_ids, dtype=torch.long)
        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
        all_segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        all_label_ids = torch.tensor(label_id, dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            eval_sampler = SequentialSampler(eval_data)
        else:
            eval_sampler = DistributedSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        logits_all = []
        for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                # tmp_eval_loss, logits = model(input_ids=input_ids.view(-1,args.max_seq_length), token_type_ids=segment_ids.view(-1,args.max_seq_length),
                #                               attention_mask=input_mask.view(-1,args.max_seq_length), labels=label_ids.view(-1))
                # tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, n_class)
                tmp_eval_loss, logits = model(input_ids=input_ids,
                                              token_type_ids=segment_ids,
                                              attention_mask=input_mask,
                                              labels=label_ids)
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            for i in range(len(logits)):
                logits_all += [logits[i]]

            tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        if args.do_train:
            result = {'eval_loss': eval_loss,
                      'eval_accuracy': eval_accuracy,
                      'global_step': global_step,
                      'loss': tr_loss / nb_tr_steps}
        else:
            result = {'eval_loss': eval_loss,
                      'eval_accuracy': eval_accuracy}

        output_eval_file = os.path.join(args.output_dir, "eval_results_test.txt")
        with open(output_eval_file, "w") as writer:
            logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
        output_eval_file = os.path.join(args.output_dir, "logits_test.txt")
        with open(output_eval_file, "w") as f:
            for i in range(len(logits_all)):
                for j in range(len(logits_all[i])):
                    f.write(str(logits_all[i][j]))
                    if j == len(logits_all[i]) - 1:
                        f.write("\n")
                    else:
                        f.write(" ")
Example #13
    def train(self,
              train_data: List[InputExample],
              eval_data: List[InputExample],
              dev32_data: List[InputExample],
              eval_config: EvalConfig,
              pattern_iter_output_dir,
              per_gpu_train_batch_size: int = 8,
              n_gpu: int = 1,
              num_train_epochs: int = 3,
              gradient_accumulation_steps: int = 1,
              weight_decay: float = 0.0,
              learning_rate: float = 5e-5,
              adam_epsilon: float = 1e-8,
              warmup_steps=0,
              max_grad_norm: float = 1,
              logging_steps: int = 50,
              max_steps=-1,
              **_):
        """
        Train the underlying language model.

        :param train_data: the training examples to use
        :param per_gpu_train_batch_size: the number of training examples per batch and gpu
        :param n_gpu: the number of gpus to use
        :param num_train_epochs: the number of epochs to train
        :param gradient_accumulation_steps: the number of gradient accumulation steps before performing an update
        :param weight_decay: the weight decay to use
        :param learning_rate: the learning rate to use
        :param adam_epsilon: epsilon parameter for the Adam optimizer
        :param warmup_steps: the number of warmup steps
        :param max_grad_norm: the maximum norm for the gradient
        :param logging_steps: the number of steps after which logging information is printed
        :param max_steps: the maximum number of training steps, overrides ``num_train_epochs``
        :return: a tuple consisting of the total number of steps and the average training loss
        """

        train_batch_size = per_gpu_train_batch_size * max(1, n_gpu)
        train_dataset = self._generate_dataset(train_data)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=train_batch_size)

        if max_steps > 0:
            t_total = max_steps
            num_train_epochs = max_steps // (max(
                1,
                len(train_dataloader) // gradient_accumulation_steps)) + 1
        else:
            t_total = len(train_dataloader
                          ) // gradient_accumulation_steps * num_train_epochs

        print("\n")
        print("num_steps_per_dataset:")
        print(len(train_dataloader) // gradient_accumulation_steps)
        print("total_steps:")
        print(t_total)
        print("num_train_epochs:")
        print(num_train_epochs)
        print("\n")

        cur_model = self.model.module if hasattr(self.model,
                                                 'module') else self.model

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in cur_model.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            weight_decay
        }, {
            'params': [
                p for n, p in cur_model.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        if self.config.prompt_encoder_type == "lstm":
            embedding_parameters = [{
                'params': [p for p in cur_model.lstm_head.parameters()]
            }, {
                'params': [p for p in cur_model.mlp_head.parameters()]
            }, {
                'params':
                [p for p in cur_model.prompt_embeddings.parameters()]
            }]
        elif self.config.prompt_encoder_type == "mlp":
            embedding_parameters = [{
                'params': [p for p in cur_model.mlp.parameters()]
            }, {
                'params':
                [p for p in cur_model.prompt_embeddings.parameters()]
            }]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=1e-5,
                          eps=adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=t_total)

        embedding_optimizer = AdamW(embedding_parameters,
                                    lr=learning_rate,
                                    eps=adam_epsilon)
        embedding_scheduler = get_linear_schedule_with_warmup(
            embedding_optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=t_total)
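        # Two parameter groups (explanatory note): the pretrained LM weights use a
        # small fixed learning rate (1e-5 above), while the prompt encoder (the
        # LSTM/MLP head and the prompt embeddings) gets its own AdamW with the
        # configurable learning_rate; both schedulers follow the same linear
        # warmup/decay over t_total steps and are stepped together in the loop below.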

        writer = SummaryWriter(
            log_dir=os.path.join(self.config.output_dir, "writer_logs"))

        ### TODO
        prev_loss = 0.0
        best_dev32_acc = 0.0
        best_dev32_f1 = 0.0
        best_global_step = 0
        best_loss = 0.0
        early_stop_epoch = 0

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        self.model.zero_grad()

        logger.info("dev32_data performance before training.")
        dev32_scores = self.eval_dev(dev32_data, eval_config, n_gpu)
        logger.info(dev32_scores)

        logger.info("eval_data performance before training.")
        dev_scores = self.eval_dev(eval_data, eval_config, n_gpu)
        logger.info(dev_scores)

        train_iterator = trange(int(num_train_epochs), desc="Epoch")
        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = {k: t.cuda() for k, t in batch.items()}

                loss = self.task_helper.train_step(
                    batch) if self.task_helper else None
                if loss is None:
                    loss = TRAIN_STEP_FUNCTIONS[MLM_WRAPPER](self)(batch)

                if n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()
                tr_loss += loss.item()

                if (step + 1) % gradient_accumulation_steps == 0:
                    ## TODO
                    writer.add_scalar("train_loss", (tr_loss - prev_loss),
                                      global_step=global_step)
                    prev_loss = tr_loss

                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   max_grad_norm)

                    optimizer.step()
                    scheduler.step()
                    embedding_optimizer.step()
                    embedding_scheduler.step()

                    self.model.zero_grad()
                    global_step += 1

                    if logging_steps > 0 and global_step % logging_steps == 0:
                        logs = {}
                        loss_scalar = (tr_loss - logging_loss) / logging_steps
                        learning_rate_scalar = scheduler.get_lr()[0]
                        logs['learning_rate'] = learning_rate_scalar
                        logs['loss'] = loss_scalar
                        logging_loss = tr_loss
                        print(json.dumps({**logs, **{'step': global_step}}))

                    ## TODO
                    if global_step % self.config.eval_every_step == 0:
                        dev32_scores = self.eval_dev(dev32_data, eval_config,
                                                     n_gpu)

                        if self.config.task_name in [
                                "cb", "record", "multirc"
                        ]:
                            f1_str = "f1" if self.config.task_name != "cb" else "f1-macro"
                            if dev32_scores[
                                    "acc"] >= best_dev32_acc and dev32_scores[
                                        f1_str] >= best_dev32_f1:

                                if dev32_scores[
                                        "acc"] > best_dev32_acc and dev32_scores[
                                            f1_str] > best_dev32_f1:
                                    early_stop_epoch = 0
                                else:
                                    early_stop_epoch += 1

                                best_dev32_acc = dev32_scores["acc"]
                                best_dev32_f1 = dev32_scores[f1_str]
                                best_global_step = global_step
                                best_loss = tr_loss

                                logger.info(
                                    "Saving trained model at {}...".format(
                                        pattern_iter_output_dir))
                                logger.info("best_dev32_acc: %.4f | best_dev32_f1: %.4f | best_global_step: %d" % \
                                            (best_dev32_acc, best_dev32_f1, best_global_step))
                                logger.info(dev32_scores)

                                self.save(pattern_iter_output_dir)
                                logger.info("eval_data performance:")
                                eval_scores = self.eval_dev(
                                    eval_data, eval_config, n_gpu)
                                logger.info(eval_scores)
                            else:
                                early_stop_epoch += 1
                                logger.info(dev32_scores)
                                logger.info(early_stop_epoch)

                        elif self.config.task_name in [
                                "rte", "wic", "boolq", "wsc", "copa"
                        ]:
                            if dev32_scores["acc"] >= best_dev32_acc:
                                if dev32_scores["acc"] > best_dev32_acc:
                                    early_stop_epoch = 0
                                else:
                                    early_stop_epoch += 1

                                best_dev32_acc = dev32_scores["acc"]
                                best_global_step = global_step
                                best_loss = tr_loss

                                logger.info(
                                    "Saving trained model at {}...".format(
                                        pattern_iter_output_dir))
                                logger.info("best_dev32_acc: %.4f | best_global_step: %d" % \
                                            (best_dev32_acc, best_global_step))

                                self.save(pattern_iter_output_dir)
                                logger.info("eval_data performance:")
                                eval_scores = self.eval_dev(
                                    eval_data, eval_config, n_gpu)
                                logger.info(eval_scores)
                            else:
                                early_stop_epoch += 1
                                logger.info(dev32_scores)
                                logger.info(early_stop_epoch)

                if 0 < max_steps < global_step or early_stop_epoch >= 10:
                    epoch_iterator.close()
                    break

            if 0 < max_steps < global_step or early_stop_epoch >= 10:
                train_iterator.close()
                break

        return best_global_step, (best_loss / best_global_step
                                  if best_global_step > 0 else -1)
def main():
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator()
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
    else:
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = args.train_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = XLNetConfig.from_pretrained(args.model_name_or_path)
    tokenizer = XLNetTokenizerFast.from_pretrained(args.model_name_or_path)
    model = XLNetForQuestionAnswering.from_pretrained(
        args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config
    )

    # Preprocessing the datasets.
    # Preprocessing is slightly different for training and evaluation.
    column_names = raw_datasets["train"].column_names

    question_column_name = "question" if "question" in column_names else column_names[0]
    context_column_name = "context" if "context" in column_names else column_names[1]
    answer_column_name = "answers" if "answers" in column_names else column_names[2]

    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"

    if args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the"
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )

    max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)

    # Training preprocessing
    def prepare_train_features(examples):
        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit with the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_token_type_ids=True,
            padding="max_length",
        )
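        # For illustration (not computed by this script): with max_seq_length=384 and
        # doc_stride=128, a long context is split into several features whose context
        # windows overlap by 128 tokens, and every feature is padded to length 384.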

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")
        # The special tokens will help us build the p_mask (which indicates the tokens that can't be in answers).
        special_tokens = tokenized_examples.pop("special_tokens_mask")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []
        tokenized_examples["is_impossible"] = []
        tokenized_examples["cls_index"] = []
        tokenized_examples["p_mask"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)
            tokenized_examples["cls_index"].append(cls_index)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples["token_type_ids"][i]
            for k, s in enumerate(special_tokens[i]):
                if s:
                    sequence_ids[k] = 3
            context_idx = 1 if pad_on_right else 0

            # Build the p_mask: non-special context tokens get 0.0, all other tokens get 1.0.
            # The cls token also gets 0.0, so an empty answer can be predicted at the cls index.
            tokenized_examples["p_mask"].append(
                [
                    0.0 if (not special_tokens[i][k] and s == context_idx) or k == cls_index else 1.0
                    for k, s in enumerate(sequence_ids)
                ]
            )

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
                tokenized_examples["is_impossible"].append(1.0)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != context_idx:
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != context_idx:
                    token_end_index -= 1
                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                    tokenized_examples["is_impossible"].append(1.0)
                else:
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)
                    tokenized_examples["is_impossible"].append(0.0)

        return tokenized_examples

    if "train" not in raw_datasets:
        raise ValueError("--do_train requires a train dataset")
    train_dataset = raw_datasets["train"]
    if args.max_train_samples is not None:
        # Select a sample from the whole dataset if the argument is specified
        train_dataset = train_dataset.select(range(args.max_train_samples))
    # Create train feature from dataset
    train_dataset = train_dataset.map(
        prepare_train_features,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not args.overwrite_cache,
    )
    if args.max_train_samples is not None:
        # The number of samples might increase during feature creation, so select only the specified max samples again
        train_dataset = train_dataset.select(range(args.max_train_samples))

    # Validation preprocessing
    def prepare_validation_features(examples):
        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit with the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_token_type_ids=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # The special tokens will help us build the p_mask (which indicates the tokens that can't be in answers).
        special_tokens = tokenized_examples.pop("special_tokens_mask")

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        # We still provide the index of the CLS token and the p_mask to the model, but not the is_impossible label.
        tokenized_examples["cls_index"] = []
        tokenized_examples["p_mask"] = []

        for i, input_ids in enumerate(tokenized_examples["input_ids"]):
            # Find the CLS token in the input ids.
            cls_index = input_ids.index(tokenizer.cls_token_id)
            tokenized_examples["cls_index"].append(cls_index)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples["token_type_ids"][i]
            for k, s in enumerate(special_tokens[i]):
                if s:
                    sequence_ids[k] = 3
            context_idx = 1 if pad_on_right else 0

            # Build the p_mask: non-special context tokens get 0.0, all other tokens get 1.0.
            tokenized_examples["p_mask"].append(
                [
                    0.0 if (not special_tokens[i][k] and s == context_idx) or k == cls_index else 1.0
                    for k, s in enumerate(sequence_ids)
                ]
            )

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(examples["id"][sample_index])

            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_idx else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    if "validation" not in raw_datasets:
        raise ValueError("--do_eval requires a validation dataset")
    eval_examples = raw_datasets["validation"]
    if args.max_val_samples is not None:
        # Select a sample from the whole dataset
        eval_examples = eval_examples.select(range(args.max_val_samples))
    # Validation Feature Creation
    eval_dataset = eval_examples.map(
        prepare_validation_features,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not args.overwrite_cache,
    )

    if args.max_val_samples is not None:
        # The number of samples might increase during feature creation, so select the required number again
        eval_dataset = eval_dataset.select(range(args.max_val_samples))

    if args.do_predict:
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        test_examples = raw_datasets["test"]
        if args.max_test_samples is not None:
            # Select a sample from the whole dataset
            test_examples = test_examples.select(range(args.max_test_samples))
        # Test Feature Creation
        test_dataset = test_examples.map(
            prepare_validation_features,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not args.overwrite_cache,
        )
        if args.max_test_samples is not None:
            # The number of samples might increase during feature creation, so select the required number again
            test_dataset = test_dataset.select(range(args.max_test_samples))

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # DataLoaders creation:
    if args.pad_to_max_length:
        # If padding was already done to max length, we use the default data collator that will just convert everything
        # to tensors.
        data_collator = default_data_collator
    else:
        # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
        # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
        # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None))

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
    )

    eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)

    if args.do_predict:
        test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
        test_dataloader = DataLoader(
            test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
        )

    # Post-processing:
    def post_processing_function(examples, features, predictions, stage="eval"):
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=args.version_2_with_negative,
            n_best_size=args.n_best_size,
            max_answer_length=args.max_answer_length,
            start_n_top=model.config.start_n_top,
            end_n_top=model.config.end_n_top,
            output_dir=args.output_dir,
            prefix=stage,
        )
        # Format the result to the format the metric expects.
        if args.version_2_with_negative:
            formatted_predictions = [
                {"id": k, "prediction_text": v, "no_answer_probability": scores_diff_json[k]}
                for k, v in predictions.items()
            ]
        else:
            formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]

        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)

    metric = load_metric("squad_v2" if args.version_2_with_negative else "squad")

    def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
        """
        Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor

        Args:
            start_or_end_logits(:obj:`tensor`):
                This is the output predictions of the model. We can only enter either start or end logits.
            dataset: the evaluation or test dataset whose length sets the number of rows
            max_len(:obj:`int`):
                The maximum length of the output tensor (see the evaluation loop below for more details).
        """

        step = 0
        # create a numpy array and fill it with -100.
        logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float32)
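        # Illustration: if max_len is 5 but one batch only produced logits of width 3,
        # the rows written from that batch keep -100 in columns 3-4; -100 simply marks
        # slots that received no prediction.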
        # Now that the array is created, populate it with the outputs gathered using accelerator.gather
        for i, output_logit in enumerate(start_or_end_logits):  # populate columns
            # Copy each gathered batch into the array, then advance `step` by the batch size.

            batch_size = output_logit.shape[0]
            cols = output_logit.shape[1]
            if step + batch_size < len(dataset):
                logits_concat[step : step + batch_size, :cols] = output_logit
            else:
                logits_concat[step:, :cols] = output_logit[: len(dataset) - step]

            step += batch_size

        return logits_concat

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
    # shorter in a multi-process setup)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
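    # For example, 1,000 dataloader batches with gradient_accumulation_steps=4 give
    # ceil(1000 / 4) = 250 update steps per epoch.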

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0

    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = outputs.loss
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            if completed_steps >= args.max_train_steps:
                break

    # Evaluation: initialize all lists to collect the batched outputs.
    model.eval()
    all_start_top_log_probs = []
    all_start_top_index = []
    all_end_top_log_probs = []
    all_end_top_index = []
    all_cls_logits = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)
            start_top_log_probs = outputs.start_top_log_probs
            start_top_index = outputs.start_top_index
            end_top_log_probs = outputs.end_top_log_probs
            end_top_index = outputs.end_top_index
            cls_logits = outputs.cls_logits

            if not args.pad_to_max_length:  # pad predictions and labels so they can be gathered across processes
                start_top_log_probs = accelerator.pad_across_processes(start_top_log_probs, dim=1, pad_index=-100)
                start_top_index = accelerator.pad_across_processes(start_top_index, dim=1, pad_index=-100)
                end_top_log_probs = accelerator.pad_across_processes(end_top_log_probs, dim=1, pad_index=-100)
                end_top_index = accelerator.pad_across_processes(end_top_index, dim=1, pad_index=-100)
                cls_logits = accelerator.pad_across_processes(cls_logits, dim=1, pad_index=-100)

            all_start_top_log_probs.append(accelerator.gather(start_top_log_probs).cpu().numpy())
            all_start_top_index.append(accelerator.gather(start_top_index).cpu().numpy())
            all_end_top_log_probs.append(accelerator.gather(end_top_log_probs).cpu().numpy())
            all_end_top_index.append(accelerator.gather(end_top_index).cpu().numpy())
            all_cls_logits.append(accelerator.gather(cls_logits).cpu().numpy())

    max_len = max([x.shape[1] for x in all_end_top_log_probs])  # Get the max_length of the tensor

    # concatenate all numpy arrays collected above
    start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, eval_dataset, max_len)
    start_top_index_concat = create_and_fill_np_array(all_start_top_index, eval_dataset, max_len)
    end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, eval_dataset, max_len)
    end_top_index_concat = create_and_fill_np_array(all_end_top_index, eval_dataset, max_len)
    all_cls_logits = np.concatenate(all_cls_logits, axis=0)

    # free the per-batch tensors left over from the last evaluation step
    del start_top_log_probs
    del start_top_index
    del end_top_log_probs
    del end_top_index

    eval_dataset.set_format(type=None, columns=list(eval_dataset.features.keys()))
    outputs_numpy = (
        start_top_log_probs_concat,
        start_top_index_concat,
        end_top_log_probs_concat,
        end_top_index_concat,
        all_cls_logits,
    )
    prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
    eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
    logger.info(f"Evaluation metrics: {eval_metric}")

    if args.do_predict:
        # initialize all lists to collect the batched test outputs

        all_start_top_log_probs = []
        all_start_top_index = []
        all_end_top_log_probs = []
        all_end_top_index = []
        all_cls_logits = []
        for step, batch in enumerate(test_dataloader):
            with torch.no_grad():
                outputs = model(**batch)
                start_top_log_probs = outputs.start_top_log_probs
                start_top_index = outputs.start_top_index
                end_top_log_probs = outputs.end_top_log_probs
                end_top_index = outputs.end_top_index
                cls_logits = outputs.cls_logits

                if not args.pad_to_max_length:  # pad predictions and labels so they can be gathered across processes
                    start_top_log_probs = accelerator.pad_across_processes(start_top_log_probs, dim=1, pad_index=-100)
                    start_top_index = accelerator.pad_across_processes(start_top_index, dim=1, pad_index=-100)
                    end_top_log_probs = accelerator.pad_across_processes(end_top_log_probs, dim=1, pad_index=-100)
                    end_top_index = accelerator.pad_across_processes(end_top_index, dim=1, pad_index=-100)
                    cls_logits = accelerator.pad_across_processes(cls_logits, dim=1, pad_index=-100)

                all_start_top_log_probs.append(accelerator.gather(start_top_log_probs).cpu().numpy())
                all_start_top_index.append(accelerator.gather(start_top_index).cpu().numpy())
                all_end_top_log_probs.append(accelerator.gather(end_top_log_probs).cpu().numpy())
                all_end_top_index.append(accelerator.gather(end_top_index).cpu().numpy())
                all_cls_logits.append(accelerator.gather(cls_logits).cpu().numpy())

        max_len = max([x.shape[1] for x in all_end_top_log_probs])  # Get the max_length of the tensor

        # concatenate all numpy arrays collected above
        start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, test_dataset, max_len)
        start_top_index_concat = create_and_fill_np_array(all_start_top_index, test_dataset, max_len)
        end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, test_dataset, max_len)
        end_top_index_concat = create_and_fill_np_array(all_end_top_index, test_dataset, max_len)
        all_cls_logits = np.concatenate(all_cls_logits, axis=0)

        # free the per-batch tensors left over from the last prediction step
        del start_top_log_probs
        del start_top_index
        del end_top_log_probs
        del end_top_index

        test_dataset.set_format(type=None, columns=list(test_dataset.features.keys()))
        outputs_numpy = (
            start_top_log_probs_concat,
            start_top_index_concat,
            end_top_log_probs_concat,
            end_top_index_concat,
            all_cls_logits,
        )

        prediction = post_processing_function(test_examples, test_dataset, outputs_numpy)
        test_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
        logger.info(f"Test metrics: {test_metric}")

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
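
# Hypothetical invocation (illustration only): the flag names below assume that parse_args
# exposes the args.* attributes used above as command-line options, and that this file is
# saved as run_qa_beam_search_no_trainer.py; adjust both to the actual parser and filename.
#
#   accelerate launch run_qa_beam_search_no_trainer.py \
#       --model_name_or_path xlnet-large-cased \
#       --dataset_name squad_v2 --version_2_with_negative \
#       --max_seq_length 384 --doc_stride 128 \
#       --per_device_train_batch_size 8 --learning_rate 3e-5 \
#       --num_train_epochs 2 --output_dir ./xlnet_squad_v2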
Example #15
    def train(self,
              task_train_data: List[InputExample],
              device,
              per_gpu_train_batch_size: int = 8,
              n_gpu: int = 1,
              num_train_epochs: int = 3,
              gradient_accumulation_steps: int = 1,
              weight_decay: float = 0.0,
              learning_rate: float = 5e-5,
              adam_epsilon: float = 1e-8,
              warmup_steps=0,
              max_grad_norm: float = 1,
              logging_steps: int = 50,
              per_gpu_unlabeled_batch_size: int = 8,
              unlabeled_data: List[InputExample] = None,
              lm_training: bool = False,
              use_logits: bool = False,
              alpha: float = 0.8,
              temperature: float = 1,
              max_steps=-1,
              **_):
        """
        Train the underlying language model.

        :param task_train_data: the training examples to use
        :param device: the training device (cpu/gpu)
        :param per_gpu_train_batch_size: the number of training examples per batch and gpu
        :param n_gpu: the number of gpus to use
        :param num_train_epochs: the number of epochs to train
        :param gradient_accumulation_steps: the number of gradient accumulation steps before performing an update
        :param weight_decay: the weight decay to use
        :param learning_rate: the learning rate to use
        :param adam_epsilon: epsilon parameter for the Adam optimizer
        :param warmup_steps: the number of warmup steps
        :param max_grad_norm: the maximum norm for the gradient
        :param logging_steps: the number of steps after which logging information is printed
        :param per_gpu_unlabeled_batch_size: the number of unlabeled examples per batch and gpu
        :param unlabeled_data: the unlabeled examples to use
        :param lm_training: whether to perform auxiliary language modeling (only for MLMs)
        :param use_logits: whether to use the example's logits instead of their labels to compute the loss
        :param alpha: the alpha parameter for auxiliary language modeling
        :param temperature: the temperature for knowledge distillation
        :param max_steps: the maximum number of training steps, overrides ``num_train_epochs``
        :return: a tuple consisting of the total number of steps and the average training loss
        """

        train_batch_size = per_gpu_train_batch_size * max(1, n_gpu)
        train_dataset = self._generate_dataset(task_train_data)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=train_batch_size)

        unlabeled_dataloader, unlabeled_iter = None, None

        if lm_training or use_logits:
            # we need unlabeled data both for auxiliary language modeling and for knowledge distillation
            assert unlabeled_data is not None
            unlabeled_batch_size = per_gpu_unlabeled_batch_size * max(1, n_gpu)
            unlabeled_dataset = self._generate_dataset(unlabeled_data,
                                                       labelled=False)
            unlabeled_sampler = RandomSampler(unlabeled_dataset)
            unlabeled_dataloader = DataLoader(unlabeled_dataset,
                                              sampler=unlabeled_sampler,
                                              batch_size=unlabeled_batch_size)
            unlabeled_iter = unlabeled_dataloader.__iter__()

        if use_logits:
            train_dataloader = unlabeled_dataloader

        if max_steps > 0:
            t_total = max_steps
            num_train_epochs = max_steps // (max(
                1,
                len(train_dataloader) // gradient_accumulation_steps)) + 1
        else:
            t_total = len(train_dataloader
                          ) // gradient_accumulation_steps * num_train_epochs
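        # Example: 500 batches, gradient_accumulation_steps=2 and num_train_epochs=3
        # give t_total = (500 // 2) * 3 = 750 optimization steps.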

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            weight_decay
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=learning_rate,
                          eps=adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=t_total)

        # multi-gpu training
        if n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        step = 0
        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        self.model.zero_grad()

        train_iterator = trange(int(num_train_epochs), desc="Epoch")

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for _, batch in enumerate(epoch_iterator):
                self.model.train()
                unlabeled_batch = None

                batch = {k: t.to(device) for k, t in batch.items()}

                if lm_training:
                    while unlabeled_batch is None:
                        try:
                            unlabeled_batch = unlabeled_iter.__next__()
                        except StopIteration:
                            logger.info("Resetting unlabeled dataset")
                            unlabeled_iter = unlabeled_dataloader.__iter__()

                    lm_input_ids = unlabeled_batch['input_ids']
                    unlabeled_batch['input_ids'], unlabeled_batch[
                        'mlm_labels'] = self._mask_tokens(lm_input_ids)
                    unlabeled_batch = {
                        k: t.to(device)
                        for k, t in unlabeled_batch.items()
                    }

                train_step_inputs = {
                    'unlabeled_batch': unlabeled_batch,
                    'lm_training': lm_training,
                    'alpha': alpha,
                    'use_logits': use_logits,
                    'temperature': temperature
                }
                loss = self.task_helper.train_step(
                    batch, **train_step_inputs) if self.task_helper else None

                if loss is None:
                    loss = TRAIN_STEP_FUNCTIONS[self.config.wrapper_type](
                        self)(batch, **train_step_inputs)

                if n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                if (step + 1) % gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    self.model.zero_grad()
                    global_step += 1

                    if logging_steps > 0 and global_step % logging_steps == 0:
                        logs = {}
                        loss_scalar = (tr_loss - logging_loss) / logging_steps
                        learning_rate_scalar = scheduler.get_lr()[0]
                        logs['learning_rate'] = learning_rate_scalar
                        logs['loss'] = loss_scalar
                        logging_loss = tr_loss

                        print(json.dumps({**logs, **{'step': global_step}}))

                if 0 < max_steps < global_step:
                    epoch_iterator.close()
                    break
                step += 1
            if 0 < max_steps < global_step:
                train_iterator.close()
                break

        return global_step, (tr_loss / global_step if global_step > 0 else -1)
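

# A minimal, hedged sketch (illustration only) of how a temperature-scaled distillation
# term is commonly computed when training from another model's logits, as the
# ``use_logits``/``temperature`` options above suggest; the actual TRAIN_STEP_FUNCTIONS may
# combine losses differently (e.g. mixing in the auxiliary LM objective via ``alpha``).
# The function name below is hypothetical.
import torch
import torch.nn.functional as F


def soft_label_cross_entropy(student_logits, teacher_logits, temperature=1.0):
    # Teacher probabilities softened by the temperature.
    targets = F.softmax(teacher_logits / temperature, dim=-1)
    log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
    return -(targets * log_probs).sum(dim=-1).mean() * temperature ** 2
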
Example #16
def train(args, train_dataset, model, tokenizer, fh, pool):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        args.tensorboard_dir = os.path.join(args.output_dir, 'tensorboard')
        if not os.path.exists(args.tensorboard_dir):
            os.makedirs(args.tensorboard_dir)

    args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)

    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  drop_last=True)
    total_examples = len(train_dataset) * (torch.distributed.get_world_size()
                                           if args.local_rank != -1 else 1)
    batch_size = args.batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    # if args.max_steps > 0:
    #     t_total = args.max_steps
    #     args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    if args.num_train_epochs > 0:
        t_total = total_examples // batch_size * args.num_train_epochs
    args.max_steps = t_total
    model.to(args.device)
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last')
    # scheduler_last = os.path.join(checkpoint_last, 'scheduler.pt')
    optimizer_last = os.path.join(checkpoint_last, 'optimizer.pt')
    # if os.path.exists(scheduler_last):
    #     scheduler.load_state_dict(torch.load(scheduler_last, map_location="cpu"))
    if os.path.exists(optimizer_last):
        logger.warning(f"Loading optimizer from {optimizer_last}")
        optimizer.load_state_dict(
            torch.load(optimizer_last, map_location="cpu"))
    if args.local_rank == 0:
        torch.distributed.barrier()
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank % args.gpu_per_node],
            output_device=args.local_rank % args.gpu_per_node)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", total_examples)
    logger.info("  Num epoch = %d", t_total * batch_size // total_examples)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = args.start_step
    tr_loss, logging_loss, avg_loss, tr_nb = 0.0, 0.0, 0.0, global_step
    # model.resize_token_embeddings(len(tokenizer))
    model.zero_grad()
    set_seed(args)  # Added here for reproducibility

    for idx in range(args.start_epoch, int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            inputs, labels = (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, labels=labels)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                output_flag = True
                avg_loss = round(
                    np.exp((tr_loss - logging_loss) / (global_step - tr_nb)),
                    4)
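                # avg_loss is exp of the mean loss since the last log, i.e. a running
                # perplexity estimate.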
                if global_step % args.logging_steps == 0:
                    logger.info("  steps: %s  ppl: %s  lr: %s", global_step,
                                round(avg_loss, 5),
                                scheduler.get_last_lr()[0])
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    logging_loss = tr_loss
                    tr_nb = global_step

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    if args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args,
                                           model,
                                           tokenizer,
                                           eval_when_training=True)
                        for key, value in results.items():
                            logger.info("  %s = %s", key, round(value, 4))
                        output_dir = os.path.join(
                            args.output_dir,
                            '{}-{}-{}'.format(checkpoint_prefix, global_step,
                                              round(results['perplexity'], 4)))
                    else:
                        output_dir = os.path.join(
                            args.output_dir,
                            "{}-{}".format(checkpoint_prefix, global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    if args.model_type == "rnn":
                        torch.save(model_to_save.state_dict(),
                                   os.path.join(output_dir, "model.pt"))
                    else:
                        model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    # _rotate_checkpoints(args, checkpoint_prefix)
                    last_output_dir = os.path.join(args.output_dir,
                                                   'checkpoint-last')
                    if not os.path.exists(last_output_dir):
                        os.makedirs(last_output_dir)
                    if args.model_type == "rnn":
                        torch.save(model_to_save.state_dict(),
                                   os.path.join(last_output_dir, "model.pt"))
                    else:
                        model_to_save.save_pretrained(last_output_dir)
                    tokenizer.save_pretrained(last_output_dir)
                    idx_file = os.path.join(last_output_dir, 'idx_file.txt')
                    with open(idx_file, 'w', encoding='utf-8') as idxf:
                        idxf.write(str(0) + '\n')

                    torch.save(optimizer.state_dict(),
                               os.path.join(last_output_dir, "optimizer.pt"))
                    # torch.save(scheduler.state_dict(), os.path.join(last_output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                last_output_dir)

                    step_file = os.path.join(last_output_dir, 'step_file.txt')
                    with open(step_file, 'w', encoding='utf-8') as stepf:
                        stepf.write(str(global_step) + '\n')

            if args.max_steps > 0 and global_step > args.max_steps:
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            break

    return global_step, tr_loss / global_step
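

# A small, hedged helper sketch (illustration only, name hypothetical) of the weight-decay
# grouping pattern repeated by the training functions in this file: biases and LayerNorm
# weights are excluded from weight decay, everything else uses the configured value.
def group_parameters_for_weight_decay(model, weight_decay,
                                      no_decay=("bias", "LayerNorm.weight")):
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
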
def train(model, tokenizer, checkpoint, round):
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
    else:
        amp = None

    train_data = Multi_task_dataset(data_file=args.train_file,
                                    max_length=args.max_length,
                                    tokenizer=tokenizer,
                                    model_type=args.model_type)

    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=args.batch_size,
                                  shuffle=True)

    t_total = len(train_dataloader) * args.epochs
    warmup_steps = int(args.warmup_steps * t_total)
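    # args.warmup_steps is treated as a fraction of t_total here; for example, 0.1 with
    # t_total=10,000 yields 1,000 warmup steps.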
    optimizer = AdamW(model.parameters(),
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)

    if args.fp16:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fptype)

    # Load the optimizer and scheduler states from the checkpoint, if present
    checkpoint_dir = args.save_dir + "/checkpoint-" + str(
        checkpoint) + '-' + str(round)
    if os.path.isfile(os.path.join(checkpoint_dir, "optimizer.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))
        if args.fp16:
            amp.load_state_dict(
                torch.load(os.path.join(checkpoint_dir, "amp.pt")))

    # Start training
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataloader))
    logger.info("  Num Epochs = %d", args.epochs)
    logger.info("  Batch size = %d", args.batch_size)
    logger.info("  learning_rate = %s", str(args.learning_rate))
    logger.info("  Total steps = %d", t_total)
    logger.info("  warmup steps = %d", warmup_steps)
    logger.info("  Model_type = %s", args.model_type)
    logger.info("  Decoder_type = %s", args.decoder_type)
    logger.info("  vice_loss_weight = %s", str(args.vice_weight))

    # If there is no previous checkpoint, start from 0
    if checkpoint < 0:
        checkpoint = 0
        round = 0
    else:
        checkpoint += 1
        round += 1

    max_test_acc = 0
    max_test_f1 = 0

    logger.debug("  Start Batch = %d", checkpoint)
    for epoch in range(checkpoint, args.epochs):
        model.train()
        epoch_loss = []

        step = 0
        for batch in tqdm(train_dataloader, desc="Iteration", ncols=50):
            model.zero_grad()
            # Move the batch tensors to the configured device
            batch = tuple(t.to(args.device) for t in batch)

            if 'roberta' in args.model_type:
                input_ids, attention_mask, labels_main, labels_vice1, labels_vice2 = batch
                outputs = model(input_ids=input_ids.long(),
                                attention_mask=attention_mask.long(),
                                labels_main=labels_main,
                                labels_vice1=labels_vice1,
                                labels_vice2=labels_vice2,
                                model_type='roberta')
            else:
                input_ids, token_type_ids, attention_mask, labels_main, labels_vice1, labels_vice2 = batch
                outputs = model(input_ids=input_ids.long(),
                                token_type_ids=token_type_ids.long(),
                                attention_mask=attention_mask.long(),
                                labels_main=labels_main,
                                labels_vice1=labels_vice1,
                                labels_vice2=labels_vice2)

            loss = outputs[0]

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # compute gradients

            epoch_loss.append(loss.item())

            optimizer.step()
            scheduler.step()
            step += 1
            if step % 500 == 0:
                logger.debug("loss:" + str(np.array(epoch_loss).mean()))
                logger.debug(
                    'learning_rate:' +
                    str(optimizer.state_dict()['param_groups'][0]['lr']))
            if step % args.saving_steps == 0:
                round += 1
                dev_loss, dev_acc, dev_f1 = test(model=model,
                                                 tokenizer=tokenizer,
                                                 test_file=args.dev_file,
                                                 checkpoint=epoch,
                                                 round=round)
                logger.info(
                    '【DEV】Train Epoch %d, round %d: train_loss=%.4f, acc=%.4f, f1=%.4f'
                    % (epoch, round, dev_loss, dev_acc, dev_f1))

                test_loss, test_acc, test_f1 = test(model=model,
                                                    tokenizer=tokenizer,
                                                    test_file=args.test_file,
                                                    checkpoint=epoch,
                                                    round=round)
                logger.info(
                    '【TEST】Train Epoch %d, round %d: train_loss=%.4f, acc=%.4f, f1=%.4f'
                    % (epoch, round, test_loss, test_acc, test_f1))
                output_dir = args.save_dir + "/checkpoint-" + str(
                    epoch) + '-' + str(round)
                if test_acc > max_test_acc or test_f1 > max_test_f1:
                    max_test_acc = max(test_acc, max_test_acc)
                    max_test_f1 = max(test_f1, max_test_f1)

                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (model.module
                                     if hasattr(model, "module") else model)
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.debug("Saving model checkpoint to %s", output_dir)
                    if args.fp16:
                        torch.save(amp.state_dict(),
                                   os.path.join(output_dir, "amp.pt"))
                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.debug("Saving optimizer and scheduler states to %s",
                                 output_dir)
                model.train()

        # Save the end-of-epoch model checkpoint
        output_dir = args.save_dir + "/checkpoint-" + str(epoch) + '-' + str(
            round)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        model_to_save = (model.module if hasattr(model, "module") else model)
        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, "training_args.bin"))
        logger.debug("Saving model checkpoint to %s", output_dir)
        if args.fp16:
            torch.save(amp.state_dict(), os.path.join(output_dir, "amp.pt"))
        torch.save(optimizer.state_dict(),
                   os.path.join(output_dir, "optimizer.pt"))
        torch.save(scheduler.state_dict(),
                   os.path.join(output_dir, "scheduler.pt"))
        logger.debug("Saving optimizer and scheduler states to %s", output_dir)

        dev_loss, dev_acc, dev_f1 = test(model=model,
                                         tokenizer=tokenizer,
                                         test_file=args.dev_file,
                                         checkpoint=epoch,
                                         round=round)
        test_loss, test_acc, test_f1 = test(model=model,
                                            tokenizer=tokenizer,
                                            test_file=args.test_file,
                                            checkpoint=epoch,
                                            round=round)
        #print(test_loss, test_acc)
        logger.info(
            '【DEV】Train Epoch %d, round %d: train_loss=%.4f, acc=%.4f, f1=%.4f'
            % (epoch, round, dev_loss, dev_acc, dev_f1))
        logger.info(
            '【TEST】Train Epoch %d, round %d: train_loss=%.4f, acc=%.4f, f1=%.4f'
            % (epoch, round, test_loss, test_acc, test_f1))
        if test_acc > max_test_acc or test_f1 > max_test_f1:
            max_test_acc = max(test_acc, max_test_acc)
            max_test_f1 = max(test_f1, max_test_f1)
    logger.info('【BEST TEST ACC】: %.4f,   【BEST TEST F1】: %.4f' %
                (max_test_acc, max_test_f1))
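

# Hedged sketch (illustration only, function name hypothetical) of the resume pattern used
# above: optimizer and scheduler states live next to the model inside
# checkpoint-{epoch}-{round}, so a later run can restore them before continuing.
import os
import torch


def load_training_state(checkpoint_dir, optimizer, scheduler):
    optimizer_path = os.path.join(checkpoint_dir, "optimizer.pt")
    scheduler_path = os.path.join(checkpoint_dir, "scheduler.pt")
    if os.path.isfile(optimizer_path):
        optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
    if os.path.isfile(scheduler_path):
        scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
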
def main():
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
    accelerator = Accelerator(
        log_with="all",
        logging_dir=args.output_dir) if args.with_tracking else Accelerator()
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(
        logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name,
                                               token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"),
                      "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = datasets.DatasetDict({
            "train":
            datasets.Dataset.from_dict(
                load_dataset(args.dataset_name,
                             args.dataset_config_name)["train"][:args.n_train +
                                                                args.n_val])
        })
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                args.dataset_name,
                args.dataset_config_name,
                split=f"train[:{args.validation_split_percentage}%]",
            )
            raw_datasets["train"] = load_dataset(
                args.dataset_name,
                args.dataset_config_name,
                split=f"train[{args.validation_split_percentage}%:]",
            )
    else:
        data_files = {}
        dataset_args = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
        raw_datasets = load_dataset(extension,
                                    data_files=data_files,
                                    **dataset_args)
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{args.validation_split_percentage}%]",
                **dataset_args,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{args.validation_split_percentage}%:]",
                **dataset_args,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
    elif args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if args.model_name_or_path:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(config)

    model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    column_names = raw_datasets["train"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples):
        return tokenizer(examples[text_column_name])

    with accelerator.main_process_first():
        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )

    if args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k:
            [t[i:i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result
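
    # Illustration (not executed), assuming block_size = 4:
    #   group_texts({"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]]}) concatenates to
    #   [1, 2, 3, 4, 5, 6, 7, 8] and returns {"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8]],
    #   "labels": [[1, 2, 3, 4], [5, 6, 7, 8]]}; a trailing remainder shorter than
    #   block_size would simply be dropped.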

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    with accelerator.main_process_first():
        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )

    train_dataset = lm_datasets["train"]
    eval_dataset = lm_datasets["validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # DataLoaders creation:
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=True,
                                  collate_fn=default_data_collator,
                                  batch_size=args.per_device_train_batch_size)
    eval_dataloader = DataLoader(eval_dataset,
                                 collate_fn=default_data_collator,
                                 batch_size=args.per_device_eval_batch_size)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
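    # Bias and LayerNorm weights are conventionally excluded from weight decay in
    # Transformer fine-tuning recipes; decaying them adds little regularization benefit.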
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
    if accelerator.distributed_type == DistributedType.TPU:
        model.tie_weights()

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps /
                                          num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler)

    # Figure out how many steps we should save the Accelerator states
    if hasattr(args.checkpointing_steps, "isdigit"):
        checkpointing_steps = args.checkpointing_steps
        if args.checkpointing_steps.isdigit():
            checkpointing_steps = int(args.checkpointing_steps)
    else:
        checkpointing_steps = None
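    # At this point checkpointing_steps is an int (save every N steps), a non-numeric
    # string such as "epoch", or None when --checkpointing_steps was not passed.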

    # We need to initialize the trackers we use, and also store our configuration
    if args.with_tracking:
        experiment_config = vars(args)
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config[
            "lr_scheduler_type"].value
        accelerator.init_trackers("clm_no_trainer", experiment_config)

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(
        f"  Total optimization steps = {int(args.max_train_steps/accelerator.num_processes)}"
    )
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(
        int(args.max_train_steps / accelerator.num_processes)),
                        disable=not accelerator.is_local_main_process)
    completed_steps = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
            accelerator.print(
                f"Resumed from checkpoint: {args.resume_from_checkpoint}")
            accelerator.load_state(args.resume_from_checkpoint)
            resume_step = None
            path = args.resume_from_checkpoint
        else:
            # Get the most recent checkpoint
            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
            dirs.sort(key=os.path.getctime)
            path = dirs[
                -1]  # Sorts folders by date modified, most recent checkpoint is the last
        if "epoch" in path:
            args.num_train_epochs -= int(path.replace("epoch_", ""))
        else:
            resume_step = int(path.replace("step_", ""))
            args.num_train_epochs -= resume_step // len(train_dataloader)
            resume_step = (args.num_train_epochs *
                           len(train_dataloader)) - resume_step

    for epoch in range(args.num_train_epochs):
        model.train()
        if args.with_tracking:
            total_loss = 0
        for step, batch in enumerate(train_dataloader):
            # We need to skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
                continue
            outputs = model(**batch)
            loss = outputs.loss
            # We keep track of the loss at each epoch
            if args.with_tracking:
                total_loss += loss.detach().float()
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            if step % args.gradient_accumulation_steps == 0 or step == len(
                    train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)
            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        losses = []
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                outputs = model(**batch)

            loss = outputs.loss
            losses.append(
                accelerator.gather(loss.repeat(
                    args.per_device_eval_batch_size)))

        losses = torch.cat(losses)
        losses = losses[:len(eval_dataset)]
        try:
            perplexity = math.exp(torch.mean(losses))
        except OverflowError:
            perplexity = float("inf")
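        # Perplexity is exp(mean per-token cross-entropy); an overflow is reported as infinity.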

        logger.info(f"epoch {epoch}: perplexity: {perplexity}")

        if args.with_tracking:
            accelerator.log(
                {
                    "perplexity": perplexity,
                    "train_loss": total_loss,
                    "epoch": epoch,
                    "step": completed_steps
                }, )

        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(args.output_dir,
                                            save_function=accelerator.save)
            if accelerator.is_main_process:
                tokenizer.save_pretrained(args.output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch {epoch}",
                    blocking=False,
                    auto_lfs_prune=True)

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            accelerator.save_state(output_dir)

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir,
                                        save_function=accelerator.save)
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message="End of training",
                                 auto_lfs_prune=True)

        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
            json.dump({"perplexity": perplexity}, f)


def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
            args.model_name_or_path
            and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to gobal_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
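            # Without --mlm (causal LM), the labels are simply the inputs; the Hugging Face
            # LM head models shift them internally when computing the next-token loss.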
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if 0 < args.max_steps < global_step:
                epoch_iterator.close()
                break
        if 0 < args.max_steps < global_step:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step


# Example #20

def main():
    args = set_config()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Prepare model
    # encoder = BertForQuestionAnswering.from_pretrained(args.bert_model)
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #
    # encoder.to(device)
    # encoder.eval()
    # #freeze bert
    # # for name, param in model.named_parameters():
    # #     if "bert" in name:
    # #         param.requires_grad = False
    #
    # model = GraphFusionNet(args)
    model = DFGN_Albert.from_pretrained(
        r'/DATA/disk1/baijinguo/BERT_Pretrained/albert-xlarge-v2',
        graph_config=args)
    model.to(device)
    model.train()

    # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=1e-8)

    global_step = 0

    if args.do_train:
        # load train data
        train_examples, train_features, train_graph = get_train_feature(
            args, args.do_train, tokenizer)
        train_examples_dict = example_dict(train_examples)
        train_data = DataIteratorPack(train_features,
                                      train_examples_dict,
                                      train_graph,
                                      args.train_batch_size,
                                      device,
                                      sent_limit=40,
                                      entity_limit=80,
                                      n_layers=args.n_layers,
                                      sequential=False)
        # (features, example_dict, graph_dict, bsz, device, sent_limit, entity_limit, n_layers = 2,
        # entity_type_dict = None, sequential = False,)
        # load dev data
        eval_examples, eval_features, eval_graph = get_train_feature(
            args, not args.do_train, tokenizer)
        eval_examples_dict = example_dict(eval_examples)
        eval_data = DataIteratorPack(eval_features,
                                     eval_examples_dict,
                                     eval_graph,
                                     args.predict_batch_size,
                                     device,
                                     sent_limit=40,
                                     entity_limit=80,
                                     n_layers=args.n_layers,
                                     sequential=False)
        with open(args.predict_file) as f:
            gold = json.load(f)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        cur_patience = 0
        VERBOSE_STEP = 100
        best_dev_F1 = None
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            # model.train()
            model.train()

            total_train_loss = [0] * 5

            for step, batch in enumerate(train_data):
                # batch = tuple(t.to(device) for t in batch)  # multi-gpu does scattering it-self
                input_ids = batch["context_idxs"]
                input_mask = batch["context_mask"]
                segment_ids = batch["segment_idxs"]
                # start_positions = batch["y1"]
                # end_positions = batch["y2"]
                # q_types = batch["q_type"]

                # context_encoding = encoder(input_ids, segment_ids, input_mask)
                #
                # # loss_list = model(context_encoding, batch=batch)
                # start, end, sp, Type, softmask, ent, yp1, yp2 = model(context_encoding, batch=batch, return_yp=True)

                start, end, sp, Type, softmask, ent, yp1, yp2 = model(
                    input_ids,
                    segment_ids,
                    input_mask,
                    batch=batch,
                    return_yp=True,
                    is_train=True)
                loss_list = compute_loss(batch, start, end, sp, Type, softmask,
                                         args)

                if args.gradient_accumulation_steps > 1:
                    loss_list = loss_list / args.gradient_accumulation_steps

                loss_list[0].backward()

                if (global_step + 1) % args.grad_accumulate_step == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                global_step += 1

                for i, l in enumerate(loss_list):
                    if not isinstance(l, int):
                        total_train_loss[i] += l.item()

                if global_step % VERBOSE_STEP == 0:
                    print("-- In Epoch{}: ".format(epoch))
                    for i, l in enumerate(total_train_loss):
                        print("Avg-LOSS{}/batch/step: {}".format(
                            i, l / VERBOSE_STEP))
                    total_train_loss = [0] * 5

                # Save a trained model
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self

            train_data.refresh()
            if args.do_predict:

                eval_examples_dict = example_dict(eval_examples)
                eval_features_dict = example_dict(eval_features)

                logger.info("***** Running predictions *****")
                logger.info("  Num split examples = %d", len(eval_features))
                logger.info("  Batch size = %d", args.predict_batch_size)

                model.eval()
                all_results = []
                answer_dict = {}
                sp_dict = {}
                total_test_loss = [0] * 5
                logger.info("Start evaluating")
                for step, batch in enumerate(eval_data):
                    # batch = tuple(t.to(device) for t in batch)  # multi-gpu does scattering it-self
                    input_ids = batch["context_idxs"]
                    input_mask = batch["context_mask"]
                    segment_ids = batch["segment_idxs"]

                    if len(sp_dict) % 1000 == 0:
                        logger.info("Processing example: %d" %
                                    (len(all_results)))

                    with torch.no_grad():
                        start, end, sp, Type, softmask, ent, yp1, yp2 = model(
                            input_ids,
                            segment_ids,
                            input_mask,
                            batch=batch,
                            return_yp=True)
                        # context_encoding = encoder(input_ids, segment_ids, input_mask)
                        #
                        # # loss_list = model(context_encoding, batch=batch)
                        # start, end, sp, Type, softmask, ent, yp1, yp2 = model(context_encoding, batch=batch,
                        #                                                       return_yp=True)
                        loss_list = compute_loss(batch, start, end, sp, Type,
                                                 softmask, args)
                        Type = Type.argmax(dim=1)

                        # batch_start_logits, batch_end_logits, batch_types, sp = model(input_ids, segment_ids, input_mask, batch=batch)
                    for i, l in enumerate(loss_list):
                        if not isinstance(l, int):
                            total_test_loss[i] += l.item()

                    answer_dict_ = convert_to_tokens(
                        eval_examples_dict, eval_features_dict, batch['ids'],
                        yp1.data.cpu().numpy().tolist(),
                        yp2.data.cpu().numpy().tolist(),
                        Type.cpu().numpy())

                    answer_dict.update(answer_dict_)
                    predict_support_np = torch.sigmoid(
                        sp[:, :, 1]).data.cpu().numpy()
                    for i in range(predict_support_np.shape[0]):
                        cur_sp_pred = []
                        cur_id = batch['ids'][i]
                        for j in range(predict_support_np.shape[1]):
                            if j >= len(eval_examples_dict[cur_id].sent_names):
                                break
                            if predict_support_np[i, j] > args.sp_threshold:
                                cur_sp_pred.append(
                                    eval_examples_dict[cur_id].sent_names[j])
                        sp_dict.update({cur_id: cur_sp_pred})

                # for i, l in enumerate(total_train_loss):
                #     print("Avg-LOSS{}/batch/step: {}".format(i, l / len(eval_features)))

                prediction = {'answer': answer_dict, 'sp': sp_dict}
                output_answer_sp_file = os.path.join(
                    args.output_dir,
                    "predictions_answer_sp_{}.json".format(epoch))
                with open(output_answer_sp_file, 'w') as f:
                    json.dump(prediction, f)

                # record results
                metrics = eval(prediction, gold)
                for i, l in enumerate(total_train_loss):
                    metrics["LOSS{}".format(i)] = l / len(eval_features)
                    print("Avg-LOSS{}/batch/step: {}".format(
                        i, l / len(eval_features)))

                # fitlog.add_best_metric({"Test": metrics})

                metrics = evaluate(eval_examples_dict, answer_dict)
                print('hotpotqa | EM {:.4f} | F1 {:.4f}'.format(
                    metrics['exact_match'], metrics['f1']))
                eval_data.refresh()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    output_model_file = os.path.join(args.output_dir,
                                                     "pytorch_model.bin")
                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model it-self

                    logger.info("model save in %s" % output_model_file)
                    # model_to_save.save_pretrained(output_model_file)
                    # tokenizer.save_pretrained(args.output_dir)
                    torch.save(model_to_save.state_dict(), output_model_file)
                    cur_patience = 0

                    # model = AlbertForQuestionAnswering.from_pretrained(args.output_dir, force_download=True)
                    # # tokenizer = AlbertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
                    # model.to(device)
                else:
                    cur_patience += 1
                    if cur_patience >= 3:
                        # for param_group in optimizer.param_groups:
                        #    param_group['lr'] /= 2.0
                        # if param_group['lr'] < 1e-8:
                        #    stop_train = True
                        break


# Example #21

        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        outputs = model(input_ids,
                        attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        loss = outputs[0]
        curr_loss += loss.item()
        iteration += 1
        if iteration % plot_every == 0:
            all_loss.append(curr_loss / plot_every)
            curr_loss = 0
        if iteration % log_every == 0:
            print("iteration:", iteration, " loss", loss.item())
        loss.backward()
        optim.step()

model.eval()

torch.save(model.state_dict(), 'baseModel.pt')
nlp = spacy.blank("en")


def word_tokenize(sent):
    doc = nlp(sent)
    return [token.text for token in doc]


import collections


# Example #22

def meta_train(tasks, model, args, device, method='random', meta_iters=10000, num_updates=5, meta_batch_size=5):
    """
    We'll start with binary classifiers (2-way classification)
    for step in range(num_steps):
        # Training
        for i in num_samples:
            task_batch_train := Sample tasks based on meta_batch_size (training set) (and task frequencies)
            for task in task_batch_train:
                forward
                loss
                backward

        # Meta-training
        if step % meta_every == 0:
            task_batch_test :=  Sample tasks not included in task_batch_train
                                meta_batch_test_size (> meta_batch_size, all?)
            for task in task_batch_test:
                forward
                loss
                backward

    params:
        - tasks
        - method: method of the task sampling sequential, custom probabilities or proportional to sqrt of data size
        - custom_task_ratio: default None only pass if custom task probabilities as sampling method
        - meta_iters: number of meta-training iterations
        - num_updates: number of updates in inner loop on same task_batch
        [NOT needed!?: num_classes: number of classes (N in N-way classification.). Default 2.]
        - meta_batch_size: number of N-way tasks per meta-batch (meta-update)
    """
    # Define logging
    os.makedirs(args.save_path, exist_ok=True)
    writer = SummaryWriter(
        os.path.join(args.save_path, 'runs', '{}'.format(datetime.now()).replace(":", "_")))

    header = '      Time      Task      Iteration      Loss      Accuracy'
    log_template = '{:>10} {:>25} {:10.0f} {:10.6f} {:10.6f}'
    test_template = 'Test mean: {}, Test std: {}'

    print(header)
    start = time.time()

    # Define optimizers, lr schedulers and loss function
    optimizer_bert = AdamW(params=model.proto_net.encoder.bert.parameters(), lr=args.bert_lr)
    optimizer = optim.Adam(params=chain(model.proto_net.encoder.mlp.parameters(),
                                   model.output_layer.parameters()),
                           lr=args.lr)
    scheduler_bert = get_cosine_schedule_with_warmup(optimizer_bert, 200, meta_iters)
    scheduler = get_cosine_schedule_with_warmup(optimizer, 0, meta_iters)
    # ProtoNets always have CrossEntropy loss due to softmax output
    cross_entropy = nn.CrossEntropyLoss()

    print('Loading Tokenizer..')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
    special_tokens_dict = {'additional_special_tokens': ["[MNT]", "[URL]"]}

    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens')
    model.proto_net.encoder.bert.resize_token_embeddings(len(tokenizer))
    # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.

    # setup task sampler and task model
    sampler = TaskSampler(tasks, method=method, custom_task_ratio=args.custom_task_ratio, supp_query_split=True)
    task_model = type(model)(args)
    task_model.proto_net.encoder.bert.resize_token_embeddings(len(tokenizer))

    iterations = 0
    # Iterate over the data
    train_iter = sampler.get_iter('train', tokenizer, batch_size=args.batch_size, shuffle=True)
    model.train()

    # setup validation task and episodes for evaluation
    val_task = get_validation_task(args)
    episodes = torch.load(args.episodes)

    # dummy data to overwrite old values of task model output layer
    dummy_w = torch.randn((args.mlp_dims[-1], 2))
    dummy_b = torch.randn(2)

    average_query_loss = 0
    best_query_loss = 1e+9
    best_test_mean = -1
    best_test_last = -1
    convergence_tolerance_cnt = 0
    # outer loop (meta-iterations)
    for i in range(meta_iters):
        grads = []
        task_losses_inner = {}
        task_accuracies_inner = {}
        task_losses_outer = {}
        task_accuracies_outer = {}
        # inner loop (sample different tasks)
        for task_sample in range(meta_batch_size):
            # clone original model
            task_model.proto_net.load_state_dict(model.proto_net.state_dict())
            task_model.initialize_classifier(nn.Parameter(dummy_w), nn.Parameter(dummy_b), hard_replace=True)
            task_model.to(device)
            task_model.train()

            # new optimizer for every new task model
            task_optimizer_bert = optim.SGD(params=task_model.proto_net.encoder.bert.parameters(), lr=args.bert_lr)
            task_optimizer = optim.SGD(params=chain(task_model.proto_net.encoder.mlp.parameters(),
                                                    task_model.output_layer.parameters()),
                                       lr=args.inner_lr)

            # prepare support and query set
            batch = next(train_iter)
            support = batch[:3]
            query = batch[3:]

            # setup output layer (via meta-model's prototype network)
            proto_embeddings = model.proto_net(support[0].to(device), attention_mask=support[2].to(device))
            prototypes = model.proto_net.calculate_centroids((proto_embeddings, support[1]), sampler.get_num_classes())
            W, b = task_model.calculate_output_params(prototypes.detach())
            task_model.initialize_classifier(W, b)

            # train some iterations on support set
            for update in range(num_updates):
                task_optimizer_bert.zero_grad()
                task_optimizer.zero_grad()
                predictions = task_model(support[0].to(device), attention_mask=support[2].to(device))
                task_loss = cross_entropy(predictions, support[1].long().squeeze().to(device))
                task_loss.backward()
                task_optimizer.step()
                task_optimizer_bert.step()

            # record task losses and accuracies for logging
            task_losses_inner[sampler.get_name()] = task_loss.item()
            task_accuracies_inner[sampler.get_name()] = sampler.calculate_accuracy(predictions, support[1].to(device))

            # trick to add prototypes back to computation graph
            W = 2 * prototypes + (W - 2 * prototypes).detach()
            b = -prototypes.norm(dim=1)**2 + (b + prototypes.norm(dim=1)**2).detach()
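            # The forward values of W and b are unchanged (the detached residual cancels the
            # prototype term), but gradients from the query loss now flow through `prototypes`
            # and hence back into the meta-model's prototype network.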
            task_model.initialize_classifier(W, b, hard_replace=True)

            # calculate gradients for meta update on the query set
            predictions = task_model(query[0].to(device), attention_mask=query[2].to(device))
            query_loss = cross_entropy(predictions, query[1].long().squeeze().to(device))
            query_loss.backward()

            # record task losses and accuracies for logging
            task_losses_outer[sampler.get_name()] = query_loss.item()
            task_accuracies_outer[sampler.get_name()] = sampler.calculate_accuracy(predictions, query[1].to(device))
            average_query_loss += query_loss.item()

            # register W and b parameters again to avoid error in weight update
            W = nn.Parameter(W)
            b = nn.Parameter(b)
            task_model.initialize_classifier(W, b, hard_replace=True)

            # save gradients of first task model
            if task_sample == 0:
                for param in task_model.parameters():
                    if param.requires_grad and param.grad is not None:
                        grads.append(param.grad.clone())
            # add the gradients of all task samples
            else:
                p = 0
                for param in task_model.parameters():
                    if param.requires_grad and param.grad is not None:
                        grads[p] += param.grad.clone()
                        p += 1

        # perform meta update
        # first load/add the calculated gradients in the meta-model
        # (already contains gradients from prototype calculation)
        p = 0
        for param in model.parameters():
            if param.requires_grad and param.grad is not None:
                param.grad += grads[p]
                p += 1
        # update model parameters according to the gradients from inner loop (clear gradients afterwards)
        optimizer.step()
        optimizer_bert.step()
        scheduler.step()
        scheduler_bert.step()
        optimizer.zero_grad()
        optimizer_bert.zero_grad()

        iterations += 1
        if iterations % args.log_every == 0:
            average_query_loss /= (args.log_every*meta_batch_size)
            iter_loss = sum(task_losses_outer.values()) / len(task_losses_outer.values())
            iter_acc = sum(task_accuracies_outer.values()) / len(task_accuracies_outer.values())
            writer.add_scalar('Meta_Average/Loss/outer', iter_loss, iterations)
            writer.add_scalar('Meta_Average/Accuracy/outer', iter_acc, iterations)
            for t in tasks:
                task_name = t.get_name()
                if task_name in task_losses_inner.keys():
                    writer.add_scalar('{}/Loss/inner'.format(task_name), task_losses_inner[task_name], iterations)
                    writer.add_scalar('{}/Accuracy/inner'.format(task_name), task_accuracies_inner[task_name], iterations)
                    writer.add_scalar('{}/Loss/outer'.format(task_name), task_losses_outer[task_name], iterations)
                    writer.add_scalar('{}/Accuracy/outer'.format(task_name), task_accuracies_outer[task_name], iterations)
            print(log_template.format(
                str(timedelta(seconds=int(time.time() - start))),
                sampler.get_name(),
                iterations,
                iter_loss,
                iter_acc))

            # save best snapshot
            if average_query_loss < best_query_loss:
                best_query_loss = average_query_loss
                average_query_loss = 0
                snapshot_prefix = os.path.join(args.save_path, 'best_query')
                snapshot_path = (
                        snapshot_prefix +
                        '_loss_{:.5f}_iter_{}_model.pt'
                ).format(best_query_loss, iterations)
                model.save_model(snapshot_path)
                # Keep only the best snapshot
                for f in glob.glob(snapshot_prefix + '*'):
                    if f != snapshot_path:
                        os.remove(f)

        # evaluate in k shot fashion
        if iterations % args.eval_every == 0:
            task_model.proto_net.load_state_dict(model.proto_net.state_dict())
            task_model.initialize_classifier(nn.Parameter(dummy_w), nn.Parameter(dummy_b), hard_replace=True)
            test_mean, test_std = k_shot_testing(task_model, episodes, val_task, device, num_updates=args.inner_updates,
                                                 num_test_batches=args.num_test_batches)
            writer.add_scalar('{}/Acc'.format(val_task.get_name()), test_mean, iterations)
            writer.add_scalar('{}/STD'.format(val_task.get_name()), test_std, iterations)
            print(test_template.format(test_mean, test_std), flush=True)
            if test_mean > best_test_mean:
                best_test_mean = test_mean
                snapshot_prefix = os.path.join(args.save_path, 'best_test_{}'.format(val_task.get_name()))
                snapshot_path = (
                        snapshot_prefix +
                        '_acc_{:.5f}_iter_{}_model.pt'
                ).format(best_test_mean, iterations)
                model.save_model(snapshot_path)
                # Keep only the best snapshot
                for f in glob.glob(snapshot_prefix + '*'):
                    if f != snapshot_path:
                        os.remove(f)
            
            if test_mean > best_test_last:
                best_test_last = best_test_mean
                convergence_tolerance_cnt = 0
            else:
                convergence_tolerance_cnt += 1

            if convergence_tolerance_cnt == args.convergence_tolerance:
                break


        # saving redundant parameters
        # Save model checkpoints.
        if iterations % args.save_every == 0:
            iter_loss = sum(task_losses_outer.values()) / len(task_losses_outer.values())
            snapshot_prefix = os.path.join(args.save_path, 'snapshot')
            snapshot_path = (
                    snapshot_prefix +
                    '_iter_{}_loss_{}_model.pt'
            ).format(iterations, iter_loss)
            logging.debug('Saving model...')
            model.save_model(snapshot_path)
            # Keep only the last snapshot
            for f in glob.glob(snapshot_prefix + '*'):
                if f != snapshot_path:
                    os.remove(f)

    writer.close()


# Example #23

def train(model, tokenizer, train_dataloader, validation_dataloader,
          index_to_label, pad_token_dict, doc_start_ind_dict, device):
    def calculate_loss(lm_logits, b_labels, b_input_mask, cls_labels,
                       index_to_label, doc_start_ind_dict, loss_fct):
        batch_size = lm_logits.shape[0]
        logits_collected = []
        labels_collected = []
        for b in range(batch_size):
            logits_ind = lm_logits[b, :, :]  # seq_len x |V|
            labels_ind = b_labels[b, :]  # seq_len
            mask = b_input_mask[b, :] > 0
            maski = mask.unsqueeze(-1).expand_as(logits_ind)
            # unpad_seq_len x |V|
            logits_pad_removed = torch.masked_select(logits_ind, maski).view(
                -1, logits_ind.size(-1))
            labels_pad_removed = torch.masked_select(labels_ind,
                                                     mask)  # unpad_seq_len

            doc_start_ind = doc_start_ind_dict[index_to_label[
                cls_labels[b].item()]]
            shift_logits = logits_pad_removed[doc_start_ind -
                                              1:-1, :].contiguous()
            shift_labels = labels_pad_removed[doc_start_ind:].contiguous()
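            # The logit at position t predicts token t+1, so logits from doc_start_ind - 1 up to
            # the second-to-last position are aligned with labels from doc_start_ind onward;
            # the label/prompt prefix is excluded from the loss.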
            # Flatten the tokens
            logits_collected.append(
                shift_logits.view(-1, shift_logits.size(-1)))
            labels_collected.append(shift_labels.view(-1))

        logits_collected = torch.cat(logits_collected, dim=0)
        labels_collected = torch.cat(labels_collected, dim=0)
        loss = loss_fct(logits_collected, labels_collected)
        return loss

    optimizer = AdamW(
        model.parameters(),
        lr=5e-4,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
        eps=1e-8  # args.adam_epsilon  - default is 1e-8.
    )

    loss_fct = CrossEntropyLoss()
    sample_every = 100
    warmup_steps = 1e2
    epochs = 5
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)
    seed_val = 42
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    training_stats = []
    total_t0 = time.time()

    for epoch_i in range(0, epochs):
        print("", flush=True)
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs),
              flush=True)
        print('Training...', flush=True)
        t0 = time.time()
        total_train_loss = 0
        model.train()

        for step, batch in enumerate(train_dataloader):
            if step % sample_every == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(
                    step, len(train_dataloader), elapsed),
                      flush=True)
                model.eval()
                lbl = random.choice(list(index_to_label.values()))
                temp_list = ["<|labelpad|>"] * pad_token_dict[lbl]
                if len(temp_list) > 0:
                    label_str = " ".join(
                        lbl.split("_")) + " " + " ".join(temp_list)
                else:
                    label_str = " ".join(lbl.split("_"))
                text = tokenizer.bos_token + " " + label_str + " <|labelsep|> "
                sample_outputs = model.generate(input_ids=tokenizer.encode(
                    text, return_tensors='pt').to(device),
                                                do_sample=True,
                                                top_k=50,
                                                max_length=200,
                                                top_p=0.95,
                                                num_return_sequences=1)
                for i, sample_output in enumerate(sample_outputs):
                    print("{}: {}".format(i, tokenizer.decode(sample_output)),
                          flush=True)
                model.train()

            b_input_ids = batch[0].to(device)
            b_labels = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            cls_labels = batch[2].to(device)

            model.zero_grad()

            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)

            loss = calculate_loss(outputs[1], b_labels, b_input_mask,
                                  cls_labels, index_to_label,
                                  doc_start_ind_dict, loss_fct)
            # loss = outputs[0]
            total_train_loss += loss.item()

            loss.backward()
            optimizer.step()
            scheduler.step()

        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / len(train_dataloader)

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("", flush=True)
        print("  Average training loss: {0:.2f}".format(avg_train_loss),
              flush=True)
        print("  Training epcoh took: {:}".format(training_time), flush=True)

        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.

        print("", flush=True)
        print("Running Validation...", flush=True)

        t0 = time.time()

        model.eval()

        total_eval_loss = 0
        nb_eval_steps = 0

        # Evaluate data for one epoch
        for batch in validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[0].to(device)
            cls_labels = batch[2].to(device)

            with torch.no_grad():
                outputs = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask,
                                labels=b_labels)

            # Accumulate the validation loss.
            loss = calculate_loss(outputs[1], b_labels, b_input_mask,
                                  cls_labels, index_to_label,
                                  doc_start_ind_dict, loss_fct)
            # loss = outputs[0]
            total_eval_loss += loss.item()

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(validation_dataloader)

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)

        print("  Validation Loss: {0:.2f}".format(avg_val_loss), flush=True)
        print("  Validation took: {:}".format(validation_time), flush=True)

        # Record all statistics from this epoch.
        training_stats.append({
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        })

    print("", flush=True)
    print("Training complete!", flush=True)

    print("Total training took {:} (h:mm:ss)".format(
        format_time(time.time() - total_t0)),
          flush=True)
    return model
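
# --- Hedged helper sketch (not part of the original snippet above) ---
# format_time() is called in the training loop above but is not defined in this
# excerpt. A minimal implementation consistent with the "(h:mm:ss)" string it is
# formatted into could look like this:
import datetime

def format_time(elapsed_seconds):
    """Round a duration in seconds and render it as h:mm:ss."""
    return str(datetime.timedelta(seconds=int(round(elapsed_seconds))))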
# Example #24
def run_bert(params):
    """Google Bert Classifier
    """
    if torch.cuda.is_available():
        device = int(get_freer_gpu())
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    print('Loading datasets and oversampling training data...')
    train_df = pd.read_csv(params['data_dir'] + 'train.tsv',
                           sep='\t',
                           na_values='x')

    # oversample the minority class
    if params['balance']:
        label_count = Counter(train_df.label)
        for label_tmp in label_count:
            sample_num = label_count.most_common(
                1)[0][1] - label_count[label_tmp]
            if sample_num == 0:
                continue
            train_df = pd.concat([
                train_df, train_df[train_df.label == label_tmp].sample(
                    int(sample_num * params['balance_ratio']), replace=True)
            ])
        train_df = train_df.reset_index()  # to prevent index key error

    valid_df = pd.read_csv(params['data_dir'] + 'valid.tsv',
                           sep='\t',
                           na_values='x')
    test_df = pd.read_csv(params['data_dir'] + 'test.tsv',
                          sep='\t',
                          na_values='x')
    data_df = [train_df, valid_df, test_df]
    # We need to add special tokens at the beginning and end of each sentence for BERT to work properly
    for doc_df in data_df:
        doc_df.text = doc_df.text.apply(lambda x: '[CLS] ' + x + ' [SEP]')

    print('Padding Datasets...')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True)
    for doc_df in data_df:
        doc_df.text = doc_df.text.apply(lambda x: tokenizer.tokenize(x))

    # convert to indices and pad the sequences
    for doc_df in data_df:
        doc_df.text = doc_df.text.apply(
            lambda x: pad_sequences([tokenizer.convert_tokens_to_ids(x)],
                                    maxlen=params['max_len'],
                                    dtype="long")[0])

    # create attention masks
    for doc_df in data_df:
        attention_masks = []
        for seq in doc_df.text:
            seq_mask = [float(idx > 0) for idx in seq]
            attention_masks.append(seq_mask)
        doc_df['masks'] = attention_masks

    # format train, valid, test
    train_inputs = torch.tensor(data_df[0].text)
    train_labels = torch.tensor(data_df[0].label)
    train_masks = torch.tensor(data_df[0].masks)
    valid_inputs = torch.tensor(data_df[1].text)
    valid_labels = torch.tensor(data_df[1].label)
    valid_masks = torch.tensor(data_df[1].masks)
    test_inputs = torch.tensor(data_df[2].text)
    test_labels = torch.tensor(data_df[2].label)
    test_masks = torch.tensor(data_df[2].masks)

    batch_size = params['batch_size']

    train_data = TensorDataset(train_inputs, train_masks, train_labels)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=batch_size,
                                  num_workers=os.cpu_count())
    valid_data = TensorDataset(valid_inputs, valid_masks, valid_labels)
    valid_sampler = SequentialSampler(valid_data)
    valid_dataloader = DataLoader(valid_data,
                                  sampler=valid_sampler,
                                  batch_size=batch_size)
    test_data = TensorDataset(test_inputs, test_masks, test_labels)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data,
                                 sampler=test_sampler,
                                 batch_size=batch_size)

    # load the pretrained model
    print('Loading Pretrained Model...')
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=params['num_label'])
    model.to(device)

    # organize parameters
    param_optimizer = list(model.named_parameters())
    if params['freeze']:
        no_decay = ['bias', 'bert']  # including 'bert' turns off weight decay for all BERT parameters
    else:
        no_decay = ['bias']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        params['decay_rate']
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=params['lr'])
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=params['warm_steps'],
        num_training_steps=params['train_steps'])
    wfile = open('./results/bert_results.txt', 'a')
    wfile.write(params['data_name'] + '_________________\n')

    # Training
    print('Training the model...')
    best_valid_f1 = 0.0  # track the best validation F1 across epochs
    for epoch in trange(params['epochs'], desc='Epoch'):
        model.train()
        # Tracking variables
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0

        # train batch
        for step, batch in enumerate(train_dataloader):
            # Add batch to GPU
            batch = tuple(t.to(device) for t in batch)
            # Unpack the inputs from our dataloader
            b_input_ids, b_input_mask, b_labels = batch
            # Clear out the gradients (by default they accumulate)
            optimizer.zero_grad()
            # Forward pass
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)
            # backward pass
            outputs[0].backward()
            # outputs.backward()
            # Update parameters and take a step using the computed gradient
            optimizer.step()
            scheduler.step()

            # Update tracking variables
            tr_loss += outputs[0].item()
            nb_tr_examples += b_input_ids.size(0)
            nb_tr_steps += 1
        print("Train loss: {}".format(tr_loss / nb_tr_steps))
        '''Validation'''
        # Put model in evaluation mode to evaluate loss on the validation set
        model.eval()
        # tracking variables
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        # batch eval
        y_preds = []
        for batch in valid_dataloader:
            # Add batch to GPU
            batch = tuple(t.to(device) for t in batch)
            # Unpack the inputs from our dataloader
            b_input_ids, b_input_mask, b_labels = batch
            # Telling the model not to compute or store gradients, saving memory and speeding up validation
            with torch.no_grad():
                # Forward pass, calculate logit predictions
                outputs = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask)
            # Move logits and labels to CPU
            logits = outputs[0].detach().cpu().numpy()
            # record the prediction
            pred_flat = np.argmax(logits, axis=1).flatten()
            y_preds.extend(pred_flat)

            label_ids = b_labels.to('cpu').numpy()
            tmp_eval_accuracy = flat_accuracy(logits, label_ids)

            eval_accuracy += tmp_eval_accuracy
            nb_eval_steps += 1

        print("Validation Accuracy: {}".format(eval_accuracy / nb_eval_steps))
        # evaluate the validation f1 score
        f1_m_valid, f1_w_valid = flat_f1(y_preds, valid_df.label)
        if f1_w_valid > best_valid_f1:
            print('Test....')
            best_valid_f1 = f1_w_valid
            print('Epoch {0}, valid f1 score {1}'.format(epoch, best_valid_f1))
            y_preds = []
            y_probs = []

            # test if valid gets better results
            for batch in test_dataloader:
                batch = tuple(t.to(device) for t in batch)
                b_input_ids, b_input_mask, b_labels = batch
                with torch.no_grad():
                    outputs = model(b_input_ids,
                                    token_type_ids=None,
                                    attention_mask=b_input_mask)
                probs = torch_func.softmax(outputs[0], dim=1)
                probs = probs.detach().cpu().numpy()
                pred_flat = np.argmax(probs, axis=1).flatten()
                y_preds.extend(pred_flat)
                y_probs.extend([item[1] for item in probs])

            # save the predicted results
            wfile.write('Epoch: {}.........................\n'.format(epoch))
            wfile.write(
                str(
                    f1_score(y_pred=y_preds,
                             y_true=test_df.label,
                             average='weighted')) + '\n')
            report = classification_report(y_pred=y_preds,
                                           y_true=test_df.label,
                                           digits=3)
            print(report)
            wfile.write(report + '\n')
            wfile.write('.........................\n')
            wfile.write('\n')
    wfile.close()
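
# --- Hedged helper sketch (not part of the original run_bert source) ---
# flat_accuracy() and flat_f1() are used above but not defined in this excerpt.
# Minimal implementations consistent with those call sites might be (the macro
# average for the first return value of flat_f1 is an assumption):
import numpy as np
from sklearn.metrics import f1_score

def flat_accuracy(logits, label_ids):
    preds = np.argmax(logits, axis=1).flatten()
    labels = label_ids.flatten()
    return np.sum(preds == labels) / len(labels)

def flat_f1(y_preds, y_true):
    return (f1_score(y_true=y_true, y_pred=y_preds, average='macro'),
            f1_score(y_true=y_true, y_pred=y_preds, average='weighted'))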
# Example #25
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = (RandomSampler(train_dataset) if args.local_rank == -1 else
                     DistributedSampler(train_dataset))
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = (
            args.max_steps //
            (len(train_dataloader) // args.gradient_accumulation_steps) + 1)
    else:
        t_total = (len(train_dataloader) // args.gradient_accumulation_steps *
                   args.num_train_epochs)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info(
            "  Will skip the first %d steps in the first epoch",
            steps_trained_in_current_epoch,
        )

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3],
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps and
                (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
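
# --- Hedged illustration (not part of the original train() above) ---
# The schedule length above is derived from the dataloader length, the gradient
# accumulation factor and the epoch count: one optimizer update happens per
# `gradient_accumulation_steps` batches. A tiny standalone check of that
# arithmetic, with made-up numbers:
def total_optimization_steps(num_batches, accumulation_steps, num_epochs):
    return num_batches // accumulation_steps * num_epochs

assert total_optimization_steps(10_000, 4, 3) == 7_500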
def train(dataset, save_dir, args, balance=False, debugging=False):
    """ prepare dataset and saving directory """
    if not debugging: new_dir(save_dir)
    train, dev, test = dataset['train'], dataset['dev'], dataset['test']
    num_labels = train.num_labels()
    n = len(dataset['train'])
    """ create dataloader """
    sampler = SequentialSampler(train)
    if balance:
        frequencies = {}
        for pair in train:
            if pair[1].item() not in frequencies:
                frequencies[pair[1].item()] = 0
            frequencies[pair[1].item()] += 1
        weights = []
        for pair in train:
            weights.append(1 / frequencies[pair[1].item()])
        sampler = WeightedRandomSampler(weights=weights,
                                        num_samples=len(train))
    train_dataloader = DataLoader(train,
                                  sampler=sampler,
                                  batch_size=args['batch_size'])
    """ create model and prepare optimizer """
    #model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)
    # CONTINUE TRAINING WIKI MODEL!!!
    model = BertForSequenceClassification.from_pretrained("../wiki-epoch-5")
    #train_dataloader = DataLoader(train, batch_size=args['batch_size'], shuffle=True)
    optimizer = AdamW(model.parameters(), lr=args['lr'])
    total_steps = len(train_dataloader) * args[
        'epochs']  # number of batches * number of epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=total_steps)
    """ train model """
    print("starting training")
    train_start = time.time()
    last_save_time = time.time()  # time since last save
    last_save_count = 0
    for epoch in range(1 if debugging else args['epochs']):
        print("\nstarting epoch {} out of {}".format(epoch + 1,
                                                     args['epochs']))
        print("time taken so far: {}\n".format(time.time() - train_start))
        model.train()  # turn on train mode (turned off in evaluate)
        total_loss = 0.
        curr = 0
        predictions = np.zeros(args['batch_size'] if debugging else
                               n)  # used for confusion matrix
        truth = np.zeros(args['batch_size'] if debugging else n)
        epoch_time = time.time()
        for (x_batch,
             y_batch) in train_dataloader:  # different shuffle each time
            optimizer.zero_grad()
            output = model(x_batch, labels=y_batch)
            loss, preds = output[0], output[1]
            predictions[curr:min(n, curr + args['batch_size'])] = torch.argmax(
                preds, axis=1)
            truth[curr:min(n, curr + args['batch_size'])] = y_batch
            total_loss += loss.item()
            curr += args['batch_size']
            loss.backward()  # loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args['clip_grad_norm'])
            optimizer.step()
            scheduler.step()
            if debugging: break  # only 1 batch when debugging
            if time.time() - last_save_time > 3600:  # 3600 seconds = 1 hour
                last_save_time = time.time()
                last_save_count += 1
                print("{:4f}% through training".format(100 * curr / n))
                print("time taken so far: {}\n".format(time.time() -
                                                       train_start))
                if not debugging:
                    path = os.path.join(
                        save_dir, "save-count-{}".format(last_save_count))
                    new_dir(path)
                    model.save_pretrained(path)

        total_loss /= len(train)
        epoch_time = time.time() - epoch_time
        total_time = time.time() - train_start
        accuracy = np.mean(predictions == truth)

        # print info from this training epoch
        print(
            'train - epoch {:3d} | accuracy {:5.5f} | {:5d} samples | lr {:02.3f} | epoch time {:5.2f} | '
            'total time {:5.2f} | mean loss {:5.5f}'.format(
                epoch + 1, accuracy, len(train),
                scheduler.get_lr()[0], epoch_time, total_time, total_loss))

        evaluate(model,
                 dev,
                 num_labels,
                 batch_size=args['batch_size'],
                 debugging=debugging)  # validation
        if not debugging:
            last_save_count += 1
            path = os.path.join(save_dir, "final-{}".format(last_save_count))
            new_dir(path)
            model.save_pretrained(path)
    """ evaluate and save final model """
    print("training complete, evaluating on test dataset")
    evaluate(model, test, num_labels)
    if not debugging:
        path = os.path.join(save_dir, "final")
        new_dir(path)
        model.save_pretrained(path)
    """ 
    Note: see https://stackoverflow.com/questions/42703500/
    best-way-to-save-a-trained-model-in-pytorch — the best way to save a model
    is to save the state, then to load using
    new_model = TheModelClass(*args, **kwargs)
    new_model.load_state_dict(torch.load(path))
    """
    return model
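
# --- Hedged sketch (not original code): the state_dict save/load pattern the
# docstring above points to, with a hypothetical path ---
def save_state_dict(model, path="model_state.pt"):
    torch.save(model.state_dict(), path)

def load_state_dict_into(model_class, path="model_state.pt", **model_kwargs):
    new_model = model_class(**model_kwargs)
    new_model.load_state_dict(torch.load(path))
    return new_model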
# Example #27
def train_ft(args, model_ft, train_dataset):
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.mini_batch_size * max(1, args.n_gpu)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    args.num_train_epochs = 1
    t_total = len(train_dataloader
                  ) // args.gradient_accumulation_steps * args.num_train_epochs

    if args.warmup_proportion > 0:
        args.warmup_steps = int(t_total * args.warmup_proportion)

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in list(model_ft.named_parameters())
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in list(model_ft.named_parameters())
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model_ft, optimizer = amp.initialize(model_ft,
                                             optimizer,
                                             opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_ft = torch.nn.DataParallel(model_ft)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_ft = torch.nn.parallel.DistributedDataParallel(
            model_ft,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    tr_loss, logging_loss = 0.0, 0.0

    model_ft.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])

    set_seed(args)
    logger.info("******* train ft *************")
    for _ in range(1):
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iter(loss=X.XXX, lr=X.XXXXXXXX)",
                              disable=args.local_rank not in [-1, 0])

        for step, batch in enumerate(epoch_iterator):
            model_ft.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": batch[3],
                "label_mask": batch[4],
            }

            outputs = model_ft(**inputs)
            loss = outputs

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                epoch_iterator.set_description(
                    'Iter (loss=%5.3f) lr=%9.7f' %
                    (loss.item(), scheduler.get_lr()[0]))
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_ft.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model_ft.zero_grad()
                global_step += 1

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return model_ft
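
# --- Hedged illustration (not part of the original train_ft source) ---
# train_ft converts args.warmup_proportion into a step count only after t_total
# is known. With made-up numbers: 8,000 batches, accumulation of 2, one epoch
# and a proportion of 0.1 give 4,000 optimization steps and 400 warmup steps.
def warmup_steps_from_proportion(num_batches, accumulation_steps, epochs, proportion):
    t_total = num_batches // accumulation_steps * epochs
    return int(t_total * proportion)

assert warmup_steps_from_proportion(8_000, 2, 1, 0.1) == 400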
class EncoderMemory(Learner):
    def __init__(self, config, **kwargs):
        """
        Baseline models: sequential and multitask setup.
        """
        self.key_dim = config.learner.key_dim  # needs to be before super init
        super().__init__(config, **kwargs)

        self.lr = config.learner.lr
        self.type = config.learner.type
        self.n_epochs = config.training.epochs
        self.log_freq = config.training.log_freq
        self.loss_fn = nn.CrossEntropyLoss()

        self.n_neighbours = config.learner.n_neighbours

        # self.memory = MemoryStore(memory_size=config.learner.memory_size,
        #                           key_dim=self.key_dim,
        #                           device=self.device)
        self.logger.info(f"Instantiated memory of size {config.learner.memory_size} with key dimension {self.key_dim}")
        self.encoder = TransformerRLN(model_name=config.learner.model_name,
                                      max_length=config.data.max_length,
                                      device=self.device)
        self.logger.info("Loaded {} as model".format(self.encoder.__class__.__name__))
        # self.decoder = LSTMDecoder(key_size=config.learner.key_dim, embedding_size=TRANSFORMER_HDIM).to(self.device)
        dimensions = [TRANSFORMER_HDIM] + list(self.key_dim)
        if self.config.learner.full_reconstruction:
            self.key_encoders = [relu(nn.Linear(dim, next_dim).to(self.device)) for dim, next_dim in zip(dimensions, dimensions[1:])]
            self.key_decoders = [relu(nn.Linear(key_dim, TRANSFORMER_HDIM).to(self.device)) for key_dim in self.key_dim] 
        else:
            self.key_encoders = [nn.Linear(dim, next_dim).to(self.device) for dim, next_dim in zip(dimensions, dimensions[1:])] 
            self.key_decoders = [nn.Linear(next_dim, dim).to(self.device) for dim, next_dim in zip(dimensions, dimensions[1:])]
        self.logger.info(f"Key encoders: {self.key_encoders} -- key decoders: {self.key_decoders}")
        self.classifier = nn.Linear(TRANSFORMER_HDIM, config.data.n_classes).to(self.device)
        # self.key_classifiers = nn.Linear(self.key_dim, config.data.n_classes).to(self.device)
        self.key_classifiers = [nn.Linear(dim, config.data.n_classes).to(self.device) for dim in self.key_dim]
        self.logger.info(f"Key classifiers: {self.key_classifiers}")

        self.optimizer = AdamW([p for p in
                                itertools.chain(self.encoder.parameters(),
                                                *[key_encoder.parameters() for key_encoder in self.key_encoders],
                                                *[key_decoder.parameters() for key_decoder in self.key_decoders],
                                                *[key_classifier.parameters() for key_classifier in self.key_classifiers],
                                                self.classifier.parameters()
                                                )
                                if p.requires_grad],
                                lr=self.lr)

    def training(self, datasets, **kwargs):
        # train_datasets = {dataset_name: dataset for dataset_name, dataset in zip(datasets["order"], datasets["train"])}
        train_datasets = datasets_dict(datasets["train"], datasets["order"])
        val_datasets = datasets_dict(datasets["val"], datasets["order"])

        samples_per_task = self.config.learner.samples_per_task
        order = self.config.task_order if self.config.task_order is not None else datasets["order"]
        n_samples = [samples_per_task] * len(order) if samples_per_task is None else samples_per_task
        dataset = get_continuum(train_datasets, order=order, n_samples=n_samples)
        dataloader = DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=False)

        for text, labels, datasets in dataloader:
            output = self.training_step(text, labels)
            predictions = model_utils.make_prediction(output["logits"].detach())
            # for logging
            key_predictions = [
                model_utils.make_prediction(key_logits.detach()) for key_logits in output["key_logits"]
            ]
            # self.logger.debug(f"accuracy prediction from key embedding: {key_metrics['accuracy']}")

            self.update_tracker(output, predictions, key_predictions, labels)
            online_metrics = model_utils.calculate_metrics(predictions.tolist(), labels.tolist())
            self.metrics["online"].append({
                "accuracy": online_metrics["accuracy"],
                "examples_seen": self.examples_seen(),
                "task": datasets[0]  # assumes whole batch is from same task
            })
            if self.current_iter % self.log_freq == 0:
                self.log()
                self.write_metrics()
            if self.current_iter % self.validate_freq == 0:
                self.validate(val_datasets, n_samples=self.config.training.n_validation_samples)
            self.current_iter += 1

    def training_step(self, text, labels):
        self.set_train()
        labels = torch.tensor(labels).to(self.device)
        output = self.forward(text, labels)

        # compute losses
        loss = self.loss_fn(output["logits"], labels)
        key_losses = [self.loss_fn(key_logits, labels) for key_logits in output["key_logits"]]

        # zero accumulated gradients before the backward passes below
        self.optimizer.zero_grad()

        # backward passes
        for reconstruction_error in output["reconstruction_errors"]:
            reconstruction_error.backward(retain_graph=True)
        for key_loss in key_losses:
            key_loss.backward(retain_graph=True)
        loss.backward()

        self.optimizer.step()

        loss = loss.item()
        key_losses = [key_loss.item() for key_loss in key_losses]
        # key_loss = 0
        self.logger.debug(
            f"Loss: {loss} -- key_loss: {key_losses}"
            f" -- reconstruction errors: {[re.item() for re in output['reconstruction_errors']]}"
        )
        # self.logger.debug(f"Key Loss: {key_loss}")
        return {
            "logits": output["logits"],
            "key_logits": output["key_logits"],
            "loss": loss,
            "key_losses": key_losses,
            "reconstruction_errors": [reconstruction_error.item() for reconstruction_error in output["reconstruction_errors"]]
        }

    def forward(self, text, labels, update_memory=True):
        """
        Forward pass using memory architecture.

        Parameters
        ---
        text
        labels
        update_memory: bool
            If false, don't update memory data
        """
        input_dict = self.encoder.encode_text(text)
        text_embedding = self.encoder(input_dict)
        key_embeddings = self.encode_keys(text_embedding.detach())
        reconstructions = [key_decoder(key_embedding) for key_decoder, key_embedding in zip(self.key_decoders, key_embeddings)]
        if self.config.learner.full_reconstruction:
            reconstruction_errors = [
                # reconstruction goes to original text embedding
                ((text_embedding.detach() - reconstruction) ** 2).mean() for reconstruction in reconstructions
            ]
        else:
            reconstruction_errors = [
                ((real_embedding.detach() - reconstruction) ** 2).mean()
                for real_embedding, reconstruction in zip([text_embedding] + key_embeddings[:-1], reconstructions)
            ]

        # query_result = self.memory.query(key_embedding, self.n_neighbours)
        # prediction_embedding = self.decoder(text_embedding, query_result)
        logits = self.classifier(text_embedding)
        key_logits = [key_classifier(key_embedding) for key_classifier, key_embedding in zip(self.key_classifiers, key_embeddings)]

        # if update_memory:
        #     self.memory.add_entry(embeddings=key_embedding.detach(), labels=labels, query_result=query_result)

        return {
            "logits": logits,
            "key_logits": key_logits,
            "reconstructions": reconstructions,
            "reconstruction_errors": reconstruction_errors
        }

    def encode_keys(self, embedding):
        """
        Encode an embedding into key embeddings.
        Each key embedding is compressed using the previous key's embedding.

        Parameters
        ---
        embedding: Tensor, shape (BATCH, TRANSFORMER_HDIM)
            Embedding to be mapped to key embeddings.

        Returns
        ---
        List[tensor], one element for each key_encoder, each tensor of shape BATCH, KEY_DIM corresponding
        to that encoder, specified by self.key_dims.
        """
        key_embeddings = []
        for key_encoder in self.key_encoders:
            # TODO: use embedding.detach() to block gradients between key embedding layers?
            embedding = key_encoder(embedding)
            key_embeddings.append(embedding)
        return key_embeddings

    # def decode_keys(self, key_embeddings):
    #     """
    #     Parameters
    #     ---
    #     key_embeddings: List[tensor]
    #         Each tensor is a key embedding of shape (BATCH, key_size), sizes in the same order as
    #         self.key_dim.
        
    #     Returns
    #     ---
    #     List[tensor], one element for each key_decoder, each tensor of shape (BATCH, prev_dim),
    #     where prev_dim is the dimension of the previous key encoder. The first element should have the 
    #     """
    #     # TODO: instead of going from key
    #     decoded = [key_decoder(key_embedding) for key_decoder, key_embedding in zip(self.key_decoders, self.key_embeddings)]
    #     pass

    def reset_tracker(self):
        """Initializes dictionary that stores performance data during training for logging purposes."""
        self.tracker = {
            "losses": [],
            "key_losses": [[] for _ in range(len(self.key_dim))],
            "reconstruction_errors": [[] for _ in range(len(self.key_dim))],
            "predictions": [],
            "key_predictions": [[] for _ in range(len(self.key_dim))],
            "labels": []
        }

    def update_tracker(self, output, predictions, key_predictions, labels):
        self.tracker["losses"].append(output["loss"])
        self.tracker["predictions"].extend(predictions.tolist())
        self.tracker["labels"].extend(labels.tolist())
        for i in range(len(self.key_dim)):
            self.tracker["key_losses"][i].append(output["key_losses"][i])
            self.tracker["reconstruction_errors"][i].append(output["reconstruction_errors"][i])
            self.tracker["key_predictions"][i].extend(key_predictions[i].tolist())

    def log(self):
        """Log results during training to console and optionally other outputs

        Parameters
        ---
        metrics: dict mapping metric names to their values
        """
        loss = np.mean(self.tracker["losses"])
        key_losses = [np.mean(key_losses) for key_losses in self.tracker["key_losses"]]
        reconstruction_errors = [np.mean(reconstruction_errors) for reconstruction_errors in self.tracker["reconstruction_errors"]]
        metrics = model_utils.calculate_metrics(self.tracker["predictions"], self.tracker["labels"])
        key_metrics = [
            model_utils.calculate_metrics(key_predictions, self.tracker["labels"])
            for key_predictions in self.tracker["key_predictions"]
        ]
        key_accuracy_str = [f'{km["accuracy"]:.4f}' for km in key_metrics]
        self.logger.info(
            f"Iteration {self.current_iter + 1} - Task = {self.metrics[-1]['task']} - Metrics: Loss = {loss:.4f}, "
            f"key loss = {[f'{key_loss:.4f}' for key_loss in key_losses]}, "
            f"reconstruction error = {[f'{reconstruction_error:.4f}' for reconstruction_error in reconstruction_errors]}, "
            f"accuracy = {metrics['accuracy']:.4f} - "
            f"key accuracy = {key_accuracy_str}"
        )
        if self.config.wandb:
            log = {
                "accuracy": metrics["accuracy"],
                "precision": metrics["precision"],
                "recall": metrics["recall"],
                "f1": metrics["f1"],
                "loss": loss,
                "examples_seen": self.examples_seen()
            }
            for i, dim in enumerate(self.key_dim):
                log[f"key_accuracy_encoder_{i}_dim_{dim}"] = key_metrics[i]["accuracy"]
                log[f"key_loss_encoder_{i}_dim_{dim}"] = key_losses[i]
                log[f"reconstruction_error_encoder_{i}_dim_{dim}"] = reconstruction_errors[i]
            wandb.log(log)
        self.reset_tracker()

    def examples_seen(self):
        return (self.current_iter + 1) * self.mini_batch_size

    def evaluate(self, dataloader, update_memory=False):
        self.set_eval()
        all_losses, all_predictions, all_labels = [], [], []

        self.logger.info("Starting evaluation...")
        for i, (text, labels, datasets) in enumerate(dataloader):
            labels = torch.tensor(labels).to(self.device)
            with torch.no_grad():
                output = self.forward(text, labels, update_memory=update_memory)
                logits = output["logits"]
                loss = self.loss_fn(logits, labels)
            loss = loss.item()
            pred = model_utils.make_prediction(logits.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
            if i % 20 == 0:
                self.logger.info(f"Batch {i + 1}/{len(dataloader)} processed")

        results = model_utils.calculate_metrics(all_predictions, all_labels)
        self.logger.info("Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, "
                    "F1 score = {:.4f}".format(np.mean(all_losses), results["accuracy"],
                    results["precision"], results["recall"], results["f1"]))

        return results

    def model_state(self):
        state = {
            "encoder": self.encoder.state_dict(),
            "classifier": self.classifier.state_dict(),
        }
        for i in range(len(self.key_dim)):
            state[f"key_classifier_{i}"] = self.key_classifiers[i].state_dict()
            state[f"key_encoder_{i}"] = self.key_encoders[i].state_dict()
            state[f"key_decoder_{i}"] = self.key_decoders[i].state_dict()

        return state

    def load_model_state(self, checkpoint):
        self.encoder.load_state_dict(checkpoint["model_state"]["encoder"])
        self.classifier.load_state_dict(checkpoint["model_state"]["classifier"])
        for i in range(len(self.key_dim)):
            self.key_classifiers[i].load_state_dict(checkpoint["model_state"][f"key_classifier_{i}"])
            self.key_encoders[i].load_state_dict(checkpoint["model_state"][f"key_encoder_{i}"])
            self.key_decoders[i].load_state_dict(checkpoint["model_state"][f"key_decoder_{i}"])

    # def optimizer_state(self):
    #     return self.meta_optimizer.state_dict()

    # def load_optimizer_state(self, checkpoint):
    #     self.meta_optimizer.load_state_dict(checkpoint["optimizer"])

    def save_other_state_information(self, state):
        """Any learner specific state information is added here"""
        # state["memory"] = self.memory
        return state

    def load_other_state_information(self, checkpoint):
        """Any learner specific state information is loaded here"""
        pass
        # self.memory = checkpoint["memory"]

    def set_train(self):
        """Set underlying pytorch network to train mode.
        
        If learner has multiple models, this method should be overwritten.
        """
        self.encoder.train()

    def set_eval(self):
        """Set underlying pytorch network to evaluation mode.
        
        If learner has multiple models, this method should be overwritten.
        """
        self.encoder.eval()
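
# --- Hedged standalone sketch (not part of the EncoderMemory class above) ---
# Minimal illustration of the chained key-encoder idea documented in
# encode_keys(): each key embedding is produced from the previous one, so
# key_dim = (256, 64) compresses a 768-dim embedding as 768 -> 256 -> 64.
# The sizes and batch below are made up for the example.
import torch
import torch.nn as nn

hidden_dim, key_dims = 768, (256, 64)
dims = [hidden_dim] + list(key_dims)
key_encoders = [nn.Linear(dim, next_dim) for dim, next_dim in zip(dims, dims[1:])]

embedding = torch.randn(4, hidden_dim)           # fake batch of text embeddings
key_embeddings = []
for key_encoder in key_encoders:
    embedding = key_encoder(embedding)
    key_embeddings.append(embedding)
print([tuple(k.shape) for k in key_embeddings])  # [(4, 256), (4, 64)]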
# Example #29
            input_ids, input_masks, input_segments, labels = data
            pred = net(
                input_ids=input_ids.long().cuda(),
                labels=None,
                attention_mask=input_masks.cuda(),
                token_type_ids=input_segments.cuda(),
            )[0]
            loss = loss_fn(pred, labels.cuda())
            # Before the backward pass, use the optimizer object to zero all of the
            # gradients for the Tensors it will update (which are the learnable weights
            # of the model)

            # Backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # Calling the step function on an Optimizer makes an update to its parameters
            optimizer.step()
            optimizer.zero_grad()

            avg_loss += loss.item()

        avg_val_loss = 0.0
        net.eval()

        valid_preds = np.zeros((len(valid_idx), 30))
        true_label = np.zeros((len(valid_idx), 30))
        for j, data in enumerate(val_loader):

            # get the inputs
            #             body, answer, title, category, host, labels = data
            #             content, labels = data
            input_ids, input_masks, input_segments, labels = data
# Example #30
def train(args, train_dataset, model, tokenizer, teacher=None):
    """Train the model"""
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(log_dir=args.output_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if "mask_score" in n and p.requires_grad
            ],
            "lr":
            args.mask_scores_learning_rate,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if "mask_score" not in n and p.requires_grad and not any(
                    nd in n for nd in no_decay)
            ],
            "lr":
            args.learning_rate,
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if "mask_score" not in n and p.requires_grad and any(
                    nd in n for nd in no_decay)
            ],
            "lr":
            args.learning_rate,
            "weight_decay":
            0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    # Distillation
    if teacher is not None:
        logger.info("  Training with distillation")

    global_step = 0
    # Global TopK
    if args.global_topk:
        threshold_mem = None
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            threshold, regu_lambda = schedule_threshold(
                step=global_step,
                total_step=t_total,
                warmup_steps=args.warmup_steps,
                final_threshold=args.final_threshold,
                initial_threshold=args.initial_threshold,
                final_warmup=args.final_warmup,
                initial_warmup=args.initial_warmup,
                final_lambda=args.final_lambda,
            )
            # Global TopK
            if args.global_topk:
                if threshold == 1.0:
                    threshold = -1e2  # Or an indefinitely low quantity
                else:
                    if (threshold_mem is None) or (
                            global_step % args.global_topk_frequency_compute
                            == 0):
                        # Sort all the values to get the global topK
                        concat = torch.cat([
                            param.view(-1)
                            for name, param in model.named_parameters()
                            if "mask_scores" in name
                        ])
                        n = concat.numel()
                        kth = max(n - (int(n * threshold) + 1), 1)
                        threshold_mem = concat.kthvalue(kth).values.item()
                        threshold = threshold_mem
                    else:
                        threshold = threshold_mem
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type
                    in ["bert", "masked_bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids

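            # Masked model variants additionally take the current threshold so the
            # forward pass can decide which weights stay active.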
            if "masked" in args.model_type:
                inputs["threshold"] = threshold

            outputs = model(**inputs)
            loss, logits_stu = outputs  # model outputs are always tuple in transformers (see doc)

            # Distillation loss
            if teacher is not None:
                if "token_type_ids" not in inputs:
                    inputs["token_type_ids"] = (
                        None if args.teacher_type == "xlm" else batch[2])
                with torch.no_grad():
                    (logits_tea, ) = teacher(
                        input_ids=inputs["input_ids"],
                        token_type_ids=inputs["token_type_ids"],
                        attention_mask=inputs["attention_mask"],
                    )

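                # Standard knowledge distillation: KL divergence between the
                # temperature-softened student and teacher distributions, scaled by
                # T^2, then mixed with the supervised loss via alpha_distil/alpha_ce.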
                loss_logits = (nn.functional.kl_div(
                    input=nn.functional.log_softmax(
                        logits_stu / args.temperature, dim=-1),
                    target=nn.functional.softmax(logits_tea / args.temperature,
                                                 dim=-1),
                    reduction="batchmean",
                ) * (args.temperature**2))

                loss = args.alpha_distil * loss_logits + args.alpha_ce * loss

            # Regularization
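            # regularization() (defined elsewhere in this script) returns a sparsity
            # penalty on the mask scores; its weight regu_lambda follows the
            # schedule computed above.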
            if args.regularization is not None:
                regu_ = regularization(model=model, mode=args.regularization)
                loss = loss + regu_lambda * regu_

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
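            # Scale the loss so gradients accumulated over several micro-batches
            # average out to one effective large batch.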
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

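            # With apex AMP, the loss is scaled before backward() so small gradients
            # do not underflow in fp16; the fp32 master params are clipped below.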
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # or the epoch ends before a full accumulation window (the dataloader is shorter than gradient_accumulation_steps)
                    len(epoch_iterator) <= args.gradient_accumulation_steps and
                (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                             args.max_grad_norm)
                else:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)

                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
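                    # Log the live threshold plus per-parameter weight/gradient
                    # statistics; for mask_scores, also log the fraction of weights
                    # that would currently be retained.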
                    tb_writer.add_scalar("threshold", threshold, global_step)
                    for name, param in model.named_parameters():
                        if not param.requires_grad:
                            continue
                        tb_writer.add_scalar("parameter_mean/" + name,
                                             param.data.mean(), global_step)
                        tb_writer.add_scalar("parameter_std/" + name,
                                             param.data.std(), global_step)
                        tb_writer.add_scalar("parameter_min/" + name,
                                             param.data.min(), global_step)
                        tb_writer.add_scalar("parameter_max/" + name,
                                             param.data.max(), global_step)
                        tb_writer.add_scalar("grad_mean/" + name,
                                             param.grad.data.mean(),
                                             global_step)
                        tb_writer.add_scalar("grad_std/" + name,
                                             param.grad.data.std(),
                                             global_step)
                        if args.regularization is not None and "mask_scores" in name:
                            if args.regularization == "l1":
                                perc = (torch.sigmoid(param) > threshold
                                        ).sum().item() / param.numel()
                            elif args.regularization == "l0":
                                perc = (torch.sigmoid(param - 2 / 3 *
                                                      np.log(0.1 / 1.1))
                                        ).sum().item() / param.numel()
                            tb_writer.add_scalar(
                                "retained_weights_perc/" + name, perc,
                                global_step)

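                # One parameter update per accumulation window; global_step therefore
                # counts optimizer updates, not micro-batches.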
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()
                    logs["learning_rate"] = learning_rate_scalar[0]
                    if len(learning_rate_scalar) > 1:
                        for idx, lr in enumerate(learning_rate_scalar[1:]):
                            logs[f"learning_rate/{idx+1}"] = lr
                    logs["loss"] = loss_scalar
                    if teacher is not None:
                        logs["loss/distil"] = loss_logits.item()
                    if args.regularization is not None:
                        logs["loss/regularization"] = regu_.item()
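                    # Back out the instantaneous cross-entropy term from the combined
                    # loss. Note that `loss` has already been averaged over GPUs and
                    # scaled for gradient accumulation above, so this is only an
                    # approximation in those cases.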
                    if (teacher is not None) or (args.regularization is not None):
                        if (teacher is not None) and (args.regularization is not None):
                            logs["loss/instant_ce"] = (
                                loss.item()
                                - regu_lambda * logs["loss/regularization"]
                                - args.alpha_distil * logs["loss/distil"]
                            ) / args.alpha_ce
                        elif teacher is not None:
                            logs["loss/instant_ce"] = (
                                loss.item()
                                - args.alpha_distil * logs["loss/distil"]
                            ) / args.alpha_ce
                        else:
                            logs["loss/instant_ce"] = (
                                loss.item()
                                - regu_lambda * logs["loss/regularization"])
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

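            # Stop early once the max_steps budget of optimizer updates is exceeded.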
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

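    # Return the number of optimizer updates performed and the average training
    # loss per update.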
    return global_step, tr_loss / global_step