Example #1
    def _train_epoch(self, data_loader: DataLoader, optimizer: AdamW,
                     scheduler: LambdaLR, report_frequency: int):
        initial_time = time.time()
        total_train_loss = 0
        self.model.train()
        self.model.to(self.device)
        for step, batch in enumerate(data_loader):
            b_input_ids = batch[0].to(self.device)
            b_input_mask = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)
            optimizer.zero_grad()
            loss, logits = self.model(b_input_ids,
                                      token_type_ids=None,
                                      attention_mask=b_input_mask,
                                      labels=b_labels)
            self._report_loss_and_time(step=step + 1,
                                       num_of_batches=len(data_loader),
                                       initial_time=initial_time,
                                       loss=loss,
                                       frequency=report_frequency)
            total_train_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
        avg_train_loss = total_train_loss / len(data_loader)
        training_time = self._format_time_delta(initial_time)
        print(f"Average training loss: {avg_train_loss:.2f}")
        print(f"Training epoch took: {training_time}")
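The method above expects the caller to supply an AdamW optimizer and a LambdaLR scheduler. A minimal sketch of that setup, assuming the transformers helper get_linear_schedule_with_warmup (which returns a LambdaLR) and illustrative hyperparameters; model, train_dataloader, epochs and the trainer instance are placeholders from the surrounding context:

import torch
from transformers import get_linear_schedule_with_warmup

# Placeholders: model, train_dataloader, epochs and trainer come from the caller.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8)
num_training_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=num_training_steps)

for _ in range(epochs):
    trainer._train_epoch(train_dataloader, optimizer, scheduler,
                         report_frequency=50)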
Example #2
def main(args):
    bert = BertPretrain(args.model_name_or_path).cuda()
    optimizer = AdamW(bert.parameters(), lr=5e-5, eps=1e-8)

    # Configure tokenizer
    token_vocab_path = "bert-base-uncased-vocab.txt"
    tokenizer = BertWordPieceTokenizer(token_vocab_path, lowercase=True)

    dataset = MLMDataset(args.mlm_data_txt, tokenizer, 50)
    dataloader = DataLoader(
        dataset,
        batch_size=64,
        collate_fn=lambda samples: pad_sequence(samples, batch_first=True),
    )

    bert.train()
    for epoch in range(1, 1 + args.num_epochs):
        losses = []
        for batch in tqdm(dataloader, desc=f"Epoch {epoch:2d}", ncols=120):
            bert.zero_grad()

            inputs, labels = mask_tokens(batch, tokenizer)
            loss = bert(inputs.cuda(), labels.cuda())
            loss.backward()

            # torch.nn.utils.clip_grad_norm_(bert.parameters(), max_norm=1.0)

            optimizer.step()

            losses.append(loss.item())

        print(f"Epoch {epoch:2d} - loss: {np.mean(losses)}")

    os.makedirs(args.output_dir, exist_ok=True)
    bert.bert_model.save_pretrained(args.output_dir)
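Example #2 relies on a mask_tokens helper that is not shown. A minimal sketch of such a helper, assuming the standard BERT 80/10/10 masking scheme, the tokenizers BertWordPieceTokenizer API (token_to_id, get_vocab_size), and -100 as the ignore index expected by the masked-LM loss; the original helper may differ in detail:

import torch

def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
    # Standard BERT MLM corruption: 80% [MASK], 10% random token, 10% unchanged.
    labels = inputs.clone()
    pad_id = tokenizer.token_to_id("[PAD]")
    mask_id = tokenizer.token_to_id("[MASK]")

    probability_matrix = torch.full(labels.shape, mlm_probability)
    probability_matrix.masked_fill_(labels == pad_id, value=0.0)  # never mask padding
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # only compute the loss on masked positions

    # 80% of the masked positions become [MASK]
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = mask_id

    # 10% become a random token; the remaining 10% keep the original token
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(tokenizer.get_vocab_size(), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]
    return inputs, labels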
Example #3
class Optim(object):
    def set_parameters(self, params):
        self.params = list(params)  # careful: params may be a generator
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.lr)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.lr)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.lr)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.lr)
        elif self.method == 'bertadam':
            self.optimizer = AdamW(self.params, lr=self.lr)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self,
                 method,
                 lr,
                 max_grad_norm,
                 lr_decay=1,
                 start_decay_at=None,
                 max_decay_times=2):
        self.last_score = None
        self.decay_times = 0
        self.max_decay_times = max_decay_times
        self.lr = float(lr)
        self.max_grad_norm = max_grad_norm
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.start_decay = False

    def step(self):
        # Clip the gradient norm before the update, if configured.
        if self.max_grad_norm:
            clip_grad_norm_(self.params, self.max_grad_norm)
        self.optimizer.step()
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

    # Decay the learning rate if validation performance does not improve or we
    # have reached the start_decay_at limit (here only the epoch threshold is
    # checked; the score is just recorded).
    def updateLearningRate(self, score, epoch):
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True

        if self.start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)

        self.last_score = score
        self.optimizer.param_groups[0]['lr'] = self.lr
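A brief, hypothetical usage sketch of the Optim wrapper above with a toy model and random data; it assumes the same imports the class itself relies on (torch.optim as optim and clip_grad_norm_):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
opt = Optim(method='adam', lr=1e-3, max_grad_norm=1.0, lr_decay=0.5, start_decay_at=3)
opt.set_parameters(model.parameters())

for epoch in range(1, 6):
    x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
    opt.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    opt.step()  # clips gradients to max_grad_norm, then applies the update
    # A real caller would pass a validation score; decay starts at epoch 3 here.
    opt.updateLearningRate(score=-loss.item(), epoch=epoch)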
Example #4

    def train_func(self):
        # loss_fct = MarginRankingLoss(margin=1, reduction='mean')
        loss_fct = NLLLoss(reduction='mean')
        optimizer = AdamW(self.model.parameters(), self.args.learning_rate)
        step = 0
        # cos = nn.CosineSimilarity(dim=1)
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.args.scheduler_step,
            gamma=self.args.scheduler_gamma)
        accumulate_step = 0

        for epoch in range(1, self.args.epoch + 1):
            for batch in self.loader:
                probs = self.get_probs(batch)
                batch_size = probs.size(0)

                true_idx = torch.zeros(batch_size, dtype=torch.long)
                if torch.cuda.is_available():
                    true_idx = true_idx.cuda()
                loss = loss_fct(probs, true_idx)
                loss.backward()

                self.writer.add_scalar('loss', loss, step)

                stop_scheduler_step = self.args.scheduler_step * 80

                accumulate_step += 1
                if accumulate_step % self.args.gradient_accumulate_step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    if self.args.scheduler_lr and step <= stop_scheduler_step:
                        scheduler.step()
                    accumulate_step = 0

                step += 1
                if step % self.args.save_model_step == 0:
                    model_basename = self.args.dest_base_dir + self.args.exp_name
                    model_basename += '_epoch_{}_step_{}'.format(epoch, step)
                    torch.save(self.model.state_dict(),
                               model_basename + '.model')
                    write_json(model_basename + '.json', vars(self.args))
                    ret = self.evaluate(model_basename, step)
                    self.writer.add_scalar('accuracy', ret, step)
                    # self.writer.add_scalar('recall', ret['recall'], step)
                    # self.writer.add_scalar('f1', ret['f1'], step)
                    msg_tmpl = 'step {} completed, accuracy {:.4f}'
                    self.logger.info(msg_tmpl.format(step, ret))
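For reference, the gradient-accumulation logic in train_func reduces to the following self-contained pattern (toy model and data; accumulate_steps is illustrative). The loss is usually divided by the accumulation factor so the update approximates one larger batch:

import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accumulate_steps = 4

for step in range(1, 101):
    x, y = torch.randn(16, 8), torch.randint(0, 2, (16,))
    loss = nn.functional.cross_entropy(model(x), y) / accumulate_steps
    loss.backward()                      # gradients accumulate across micro-batches
    if step % accumulate_steps == 0:
        optimizer.step()
        optimizer.zero_grad()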
Example #5
def train(model, train_iter, dev_iter, test_iter):
    starttime = time.time()
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=decay)
    total_batch = 0
    dev_best_loss = float("inf")
    last_improve = 0
    no_improve_flag = False
    model.train()
    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        for i, (X, y) in enumerate(train_iter):
            outputs = model(X)  # batch_size * num_classes
            model.zero_grad()
            loss = F.binary_cross_entropy(outputs, y)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                truelabels = torch.max(y.data, 1)[1].cpu()
                pred = torch.max(outputs, 1)[1].cpu()
                train_acc = metrics.accuracy_score(truelabels, pred)
                dev_acc, dev_loss = evaluate(model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ' '
                time_dif = get_time_dif(starttime)
                # Print training info; in the format spec, '>' right-aligns, the number is the field width, '.2' the precision
                msg = ('Iter:{0:>6}, Train Loss:{1:>5.2}, Train Acc:{2:>6.2}, '
                       'Val Loss:{3:>5.2}, Val Acc:{4:>6.2%}, Time:{5} {6}')
                print(
                    msg.format(total_batch, loss.item(), train_acc, dev_loss,
                               dev_acc, time_dif, improve))
                model.train()
            total_batch += 1
            if total_batch - last_improve > early_stop_time:
                print(
                    "no improve after {} times, stop!".format(early_stop_time))
                no_improve_flag = True
                break
        if no_improve_flag:
            break
    test(model, test_iter)
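The loop above calls an evaluate(model, dev_iter) helper that is not shown. A minimal sketch compatible with it, assuming the same one-hot label layout and probability-valued model outputs; it returns (accuracy, mean batch loss):

import torch
import torch.nn.functional as F
from sklearn import metrics

def evaluate(model, data_iter):
    model.eval()
    total_loss, n_batches = 0.0, 0
    all_true, all_pred = [], []
    with torch.no_grad():
        for X, y in data_iter:
            outputs = model(X)
            total_loss += F.binary_cross_entropy(outputs, y).item()
            n_batches += 1
            all_true.extend(torch.max(y, 1)[1].cpu().tolist())
            all_pred.extend(torch.max(outputs, 1)[1].cpu().tolist())
    acc = metrics.accuracy_score(all_true, all_pred)
    return acc, total_loss / max(n_batches, 1)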
Example #6

class MTBModel(RelationExtractor):
    def __init__(self, config: dict):
        """
        Matching the Blanks Model.

        Args:
            config: configuration parameters
        """
        super().__init__()
        self.experiment_name = config.get("experiment_name")
        self.transformer = config.get("transformer")
        self.config = config
        self.data_loader = MTBPretrainDataLoader(self.config)
        self.train_len = len(self.data_loader.train_generator)
        logger.info("Loaded %d pre-training samples." % self.train_len)

        self.model = BertModel.from_pretrained(
            model_size=self.transformer,
            pretrained_model_name_or_path=self.transformer,
            force_download=False,
        )

        self.tokenizer = self.data_loader.tokenizer
        self.model.resize_token_embeddings(len(self.tokenizer))
        e1_id = self.tokenizer.convert_tokens_to_ids("[E1]")
        e2_id = self.tokenizer.convert_tokens_to_ids("[E2]")
        if e1_id == e2_id == 1:
            raise ValueError("e1_id == e2_id == 1")

        self.train_on_gpu = torch.cuda.is_available() and config.get(
            "use_gpu", True)
        if self.train_on_gpu:
            logger.info("Train on GPU")
            self.model.cuda()

        self.criterion = MTBLoss(lm_ignore_idx=self.tokenizer.pad_token_id)
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.01,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.config.get("lr"))
        ovr_steps = (self.config.get("epochs") *
                     len(self.data_loader.train_generator) *
                     self.config.get("max_size") * 2 /
                     self.config.get("batch_size"))
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, ovr_steps // 10, ovr_steps)

        self._start_epoch = 0
        self._best_mtb_bce = 50
        self._train_loss = []
        self._train_lm_acc = []
        self._lm_acc = []
        self._mtb_bce = []
        self.checkpoint_dir = os.path.join("models", "MTB-pretraining",
                                           self.experiment_name,
                                           self.transformer)
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

        self._batch_points_seen = 0
        self._points_seen = 0

    def load_best_model(self, checkpoint_dir: str):
        """
        Loads the current best model in the checkpoint directory.

        Args:
            checkpoint_dir: Checkpoint directory path
        """
        checkpoint = super().load_best_model(checkpoint_dir)
        return (
            checkpoint["epoch"],
            checkpoint["best_mtb_bce"],
            checkpoint["losses_per_epoch"],
            checkpoint["accuracy_per_epoch"],
            checkpoint["lm_acc"],
            checkpoint["blanks_mse"],
        )

    def train(self, **kwargs):
        """
        Runs the training.

        Args:
            kwargs: Additional keyword arguments
        """
        save_best_model_only = kwargs.get("save_best_model_only", False)
        results_path = os.path.join(
            "results",
            "MTB-pretraining",
            self.experiment_name,
            self.transformer,
        )
        best_model_path = os.path.join(self.checkpoint_dir,
                                       "best_model.pth.tar")
        resume = self.config.get("resume", False)
        if resume and os.path.exists(best_model_path):
            (
                self._start_epoch,
                self._best_mtb_bce,
                self._train_loss,
                self._train_lm_acc,
                self._lm_acc,
                self._mtb_bce,
            ) = self.load_best_model(self.checkpoint_dir)

        logger.info("Starting training process")
        update_size = len(self.data_loader.train_generator) // 100
        for epoch in range(self._start_epoch, self.config.get("epochs")):
            self._train_epoch(epoch, update_size, save_best_model_only)
            data = self._write_kpis(results_path)
            self._plot_results(data, results_path)
        logger.info("Finished Training.")
        return self.model

    def _plot_results(self, data, save_at):
        fig, ax = plt.subplots(figsize=(20, 20))
        sns.lineplot(x="Epoch", y="Train Loss", ax=ax, data=data, linewidth=4)
        ax.set_title("Training Loss")
        plt.savefig(
            os.path.join(save_at,
                         "train_loss_{0}.png".format(self.transformer)))
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(20, 20))
        sns.lineplot(x="Epoch",
                     y="Val MTB Loss",
                     ax=ax,
                     data=data,
                     linewidth=4)
        ax.set_title("Val MTB Binary Cross Entropy")
        plt.savefig(
            os.path.join(save_at,
                         "val_mtb_bce_{0}.png".format(self.transformer)))
        plt.close(fig)

        tmp = data[["Epoch", "Train LM Accuracy",
                    "Val LM Accuracy"]].melt(id_vars="Epoch",
                                             var_name="Set",
                                             value_name="LM Accuracy")
        fig, ax = plt.subplots(figsize=(20, 20))
        sns.lineplot(
            x="Epoch",
            y="LM Accuracy",
            hue="Set",
            ax=ax,
            data=tmp,
            linewidth=4,
        )
        ax.set_title("LM Accuracy")
        plt.savefig(
            os.path.join(save_at, "lm_acc_{0}.png".format(self.transformer)))
        plt.close(fig)

    def _write_kpis(self, results_path):
        Path(results_path).mkdir(parents=True, exist_ok=True)
        data = pd.DataFrame({
            "Epoch": np.arange(len(self._train_loss)),
            "Train Loss": self._train_loss,
            "Train LM Accuracy": self._train_lm_acc,
            "Val LM Accuracy": self._lm_acc,
            "Val MTB Loss": self._mtb_bce,
        })
        data.to_csv(
            os.path.join(results_path,
                         "kpis_{0}.csv".format(self.transformer)),
            index=False,
        )
        return data

    def _train_epoch(self,
                     epoch,
                     update_size,
                     save_best_model_only: bool = False):
        start_time = super()._train_epoch(epoch)

        train_lm_acc, train_loss, train_mtb_bce = [], [], []

        for i, data in enumerate(tqdm(self.data_loader.train_generator)):
            sequence, masked_label, e1_e2_start, blank_labels = data
            if sequence.shape[1] > 70:
                continue
            res = self._train_on_batch(sequence, masked_label, e1_e2_start,
                                       blank_labels)
            if res[0]:
                train_loss.append(res[0])
                train_lm_acc.append(res[1])
                train_mtb_bce.append(res[2])
            if (i % update_size) == (update_size - 1):
                logger.info(
                    f"{i+1}/{self.train_len} pools: - " +
                    f"Train loss: {np.mean(train_loss)}, " +
                    f"Train LM accuracy: {np.mean(train_lm_acc)}, " +
                    f"Train MTB Binary Cross Entropy {np.mean(train_mtb_bce)}")

        self._train_loss.append(np.mean(train_loss))
        self._train_lm_acc.append(np.mean(train_lm_acc))

        self.on_epoch_end(epoch, self._mtb_bce, self._best_mtb_bce,
                          save_best_model_only)

        logger.info(
            f"Epoch finished, took {time.time() - start_time} seconds!")
        logger.info(f"Train Loss: {self._train_loss[-1]}!")
        logger.info(f"Train LM Accuracy: {self._train_lm_acc[-1]}!")
        logger.info(f"Validation LM Accuracy: {self._lm_acc[-1]}!")
        logger.info(
            f"Validation MTB Binary Cross Entropy: {self._mtb_bce[-1]}!")

    def on_epoch_end(self,
                     epoch,
                     benchmark,
                     baseline,
                     save_best_model_only: bool = False):
        """
        Function to run at the end of an epoch.

        Runs the evaluation method, increments the scheduler, sets a new baseline and appends the KPIs.

        Args:
            epoch: Current epoch
            benchmark: List of benchmark results
            baseline: Current baseline. Best model performance so far
            save_best_model_only: Whether to only save the best model so far
                or all of them
        """
        eval_result = super().on_epoch_end(epoch, benchmark, baseline)
        self._best_mtb_bce = (eval_result[1]
                              if eval_result[1] < self._best_mtb_bce else
                              self._best_mtb_bce)
        self._mtb_bce.append(eval_result[1])
        self._lm_acc.append(eval_result[0])
        super().save_on_epoch_end(self._mtb_bce, self._best_mtb_bce, epoch,
                                  save_best_model_only)

    def _train_on_batch(
        self,
        sequence,
        mskd_label,
        e1_e2_start,
        blank_labels,
    ):
        mskd_label = mskd_label[(mskd_label != self.tokenizer.pad_token_id)]
        if mskd_label.shape[0] == 0:
            return None, None, None
        if self.train_on_gpu:
            mskd_label = mskd_label.cuda()
        blanks_logits, lm_logits = self._get_logits(e1_e2_start, sequence)
        loss = self.criterion(
            lm_logits,
            blanks_logits,
            mskd_label,
            blank_labels,
        )
        loss_p = loss.item()
        loss = loss / self.config.get("batch_size")
        loss.backward()
        self._batch_points_seen += len(sequence)
        self._points_seen += len(sequence)
        if self._batch_points_seen > self.config.get("batch_size"):
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()
            self._batch_points_seen = 0
        train_metrics = self.calculate_metrics(
            lm_logits,
            blanks_logits,
            mskd_label,
            blank_labels,
        )
        return loss_p / len(sequence), train_metrics[0], train_metrics[1]

    def _save_model(self, path, epoch, best_model: bool = False):
        if best_model:
            model_path = os.path.join(path, "best_model.pth.tar")
        else:
            model_path = os.path.join(
                path, "checkpoint_epoch_{0}.pth.tar".format(epoch + 1))
        torch.save(
            {
                "epoch": epoch + 1,
                "state_dict": self.model.state_dict(),
                "tokenizer": self.tokenizer,
                "best_mtb_bce": self._best_mtb_bce,
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
                "losses_per_epoch": self._train_loss,
                "accuracy_per_epoch": self._train_lm_acc,
                "lm_acc": self._lm_acc,
                "blanks_mse": self._mtb_bce,
            },
            model_path,
        )

    def _get_logits(self, e1_e2_start, x):
        attention_mask = (x != self.tokenizer.pad_token_id).float()
        token_type_ids = torch.zeros((x.shape[0], x.shape[1])).long()
        if self.train_on_gpu:
            x = x.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
        blanks_logits, lm_logits = self.model(
            x,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            e1_e2_start=e1_e2_start,
        )
        lm_logits = lm_logits[(x == self.tokenizer.mask_token_id)]
        return blanks_logits, lm_logits

    def evaluate(self) -> tuple:
        """
        Run the validation generator and return performance metrics.
        """
        total_loss = []
        lm_acc = []
        blanks_mse = []

        self.model.eval()
        with torch.no_grad():
            for data in self.data_loader.validation_generator:
                (x, masked_label, e1_e2_start, blank_labels) = data
                masked_label = masked_label[(masked_label !=
                                             self.tokenizer.pad_token_id)]
                if masked_label.shape[0] == 0:
                    continue
                if self.train_on_gpu:
                    masked_label = masked_label.cuda()
                blanks_logits, lm_logits = self._get_logits(e1_e2_start, x)

                loss = self.criterion(
                    lm_logits,
                    blanks_logits,
                    masked_label,
                    blank_labels,
                )

                total_loss.append(loss.item())
                eval_result = self.calculate_metrics(lm_logits, blanks_logits,
                                                     masked_label,
                                                     blank_labels)
                lm_acc += [eval_result[0]]
                blanks_mse += [eval_result[1]]
        self.model.train()
        return (
            np.mean(lm_acc),
            sum(b for b in blanks_mse if b != 1) /
            len([b for b in blanks_mse if b != 1]),
        )

    def calculate_metrics(
        self,
        lm_logits,
        blanks_logits,
        masked_for_pred,
        blank_labels,
    ) -> tuple:
        """
        Calculates the performance metrics of the MTB model.

        Args:
            lm_logits: Language model Logits per word in vocabulary
            blanks_logits: Blank logits
            masked_for_pred: List of marked tokens
            blank_labels: Blank labels
        """
        lm_logits_pred_ids = torch.softmax(lm_logits, dim=-1).max(1)[1]
        lm_accuracy = ((lm_logits_pred_ids == masked_for_pred).sum().float() /
                       len(masked_for_pred)).item()

        pos_idxs = np.where(blank_labels == 1)[0]
        neg_idxs = np.where(blank_labels == 0)[0]

        if len(pos_idxs) > 1:
            # positives
            pos_logits = []
            for pos1, pos2 in combinations(pos_idxs, 2):
                pos_logits.append(
                    self._get_mtb_logits(blanks_logits[pos1, :],
                                         blanks_logits[pos2, :]))
            pos_logits = torch.stack(pos_logits, dim=0)
            pos_labels = [1.0 for _ in range(pos_logits.shape[0])]
        else:
            pos_logits, pos_labels = torch.FloatTensor([]), []
            if blanks_logits.is_cuda:
                pos_logits = pos_logits.cuda()

        # negatives
        neg_logits = []
        for pos_idx in pos_idxs:
            for neg_idx in neg_idxs:
                neg_logits.append(
                    MTBModel._get_mtb_logits(blanks_logits[pos_idx, :],
                                             blanks_logits[neg_idx, :]))
        if neg_logits:
            neg_logits = torch.stack(neg_logits, dim=0)
            neg_labels = [0.0 for _ in range(neg_logits.shape[0])]
        else:
            # Mirror the positive case: no positive/negative pairs in this batch.
            neg_logits, neg_labels = torch.FloatTensor([]), []
            if blanks_logits.is_cuda:
                neg_logits = neg_logits.cuda()

        blank_labels = torch.FloatTensor(pos_labels + neg_labels)
        blank_pred = torch.cat([pos_logits, neg_logits], dim=0)
        bce = nn.BCEWithLogitsLoss(reduction="mean")(
            blank_pred.detach().cpu(), blank_labels.detach().cpu())

        return lm_accuracy, bce.numpy()

    @classmethod
    def _get_mtb_logits(cls, f1_vec, f2_vec):
        factor = 1 / (torch.norm(f1_vec) * torch.norm(f2_vec))
        return factor * torch.dot(f1_vec, f2_vec)
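The optimizer setup in MTBModel.__init__ is the common pattern of excluding biases and LayerNorm parameters from weight decay and pairing AdamW with a linear warmup schedule. A standalone sketch of that pattern with an illustrative model and step count, using torch.optim.AdamW and transformers.get_linear_schedule_with_warmup:

import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(16, 4)  # stand-in for the BERT encoder
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
grouped_parameters = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(grouped_parameters, lr=2e-5)
total_steps = 1000  # illustrative; MTBModel derives this from epochs and batch size
scheduler = get_linear_schedule_with_warmup(optimizer, total_steps // 10, total_steps)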
Example #7

    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        global best_dev
        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)
        adversarial_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer_G = torch.optim.Adam(G.parameters(),
                                       lr=args.G_lr)  # optimizer for generator
        optimizer_D = torch.optim.Adam(
            D.parameters(), lr=args.D_lr)  # optimizer for discriminator
        optimizer_E = AdamW(E.parameters(), args.bert_lr)
        optimizer_detector = torch.optim.Adam(detector.parameters(),
                                              lr=args.detector_lr)

        G_total_train_loss = []
        D_total_fake_loss = []
        D_total_real_loss = []
        FM_total_train_loss = []
        D_total_class_loss = []
        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []
        detector_total_train_loss = []

        all_features = []
        result = dict()

        for i in range(args.n_epoch):

            # Initialize model state
            G.train()
            D.train()
            E.train()
            detector.train()

            G_train_loss = 0
            D_fake_loss = 0
            D_real_loss = 0
            FM_train_loss = 0
            D_class_loss = 0
            detector_train_loss = 0

            for sample in tqdm.tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                ood_sample = (y == 0.0)
                # weight = torch.ones(len(ood_sample)).to(device) - ood_sample * args.beta
                # real_loss_func = torch.nn.BCELoss(weight=weight).to(device)

                # the label used to train generator and discriminator.
                valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
                fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

                optimizer_E.zero_grad()
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                # train D on real
                optimizer_D.zero_grad()
                real_f_vector, discriminator_output, classification_output = D(
                    real_feature, return_feature=True)
                # discriminator_output = discriminator_output.squeeze()
                real_loss = adversarial_loss(discriminator_output, valid_label)
                real_loss.backward(retain_graph=True)

                if args.do_vis:
                    all_features.append(real_f_vector.detach())

                # # train D on fake
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                fake_discriminator_output = D.detect_only(fake_feature)
                fake_loss = adversarial_loss(fake_discriminator_output,
                                             fake_label)
                fake_loss.backward()
                optimizer_D.step()

                # if args.fine_tune:
                #     optimizer_E.step()

                # train G
                optimizer_G.zero_grad()
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                fake_f_vector, D_decision = D.detect_only(G(z),
                                                          return_feature=True)
                gd_loss = adversarial_loss(D_decision, valid_label)
                fm_loss = torch.abs(
                    torch.mean(real_f_vector.detach(), 0) -
                    torch.mean(fake_f_vector, 0)).mean()
                g_loss = gd_loss + 0 * fm_loss
                g_loss.backward()
                optimizer_G.step()

                optimizer_E.zero_grad()

                # train detector
                optimizer_detector.zero_grad()
                if args.model == 'lstm_gan' or args.model == 'cnn_gan':
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, 32, args.G_z_dim))).to(device)
                else:
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                if args.loss == 'v1':
                    loss_fake = adversarial_loss(
                        detector(fake_feature),
                        fake_label)  # fake sample is ood
                else:
                    loss_fake = adversarial_loss_v2(
                        detector(fake_feature),
                        fake_label.long().squeeze())
                if args.loss == 'v1':
                    loss_real = adversarial_loss(detector(real_feature),
                                                 y.float())
                else:
                    loss_real = adversarial_loss_v2(detector(real_feature),
                                                    y.long())
                if args.detect_loss == 'v1':
                    detector_loss = args.beta * loss_fake + (
                        1 - args.beta) * loss_real
                else:
                    detector_loss = args.beta * loss_fake + loss_real
                    detector_loss = args.sigma * detector_loss
                detector_loss.backward()
                optimizer_detector.step()

                if args.fine_tune:
                    optimizer_E.step()

                global_step += 1

                D_fake_loss += fake_loss.detach()
                D_real_loss += real_loss.detach()
                G_train_loss += g_loss.detach() + fm_loss.detach()
                FM_train_loss += fm_loss.detach()
                detector_train_loss += detector_loss.detach()

            logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(
                i, D_fake_loss / n_sample))
            logger.info('[Epoch {}] Train: D_real_loss: {}'.format(
                i, D_real_loss / n_sample))
            logger.info('[Epoch {}] Train: D_class_loss: {}'.format(
                i, D_class_loss / n_sample))
            logger.info('[Epoch {}] Train: G_train_loss: {}'.format(
                i, G_train_loss / n_sample))
            logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(
                i, FM_train_loss / n_sample))
            logger.info('[Epoch {}] Train: detector_train_loss: {}'.format(
                i, detector_train_loss / n_sample))
            logger.info(
                '---------------------------------------------------------------------------'
            )

            D_total_fake_loss.append(D_fake_loss / n_sample)
            D_total_real_loss.append(D_real_loss / n_sample)
            D_total_class_loss.append(D_class_loss / n_sample)
            G_total_train_loss.append(G_train_loss / n_sample)
            FM_total_train_loss.append(FM_train_loss / n_sample)
            detector_total_train_loss.append(detector_train_loss / n_sample)

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(
                    eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                # 1  -> save the model
                # 0  -> no need to save the model
                # -1 -> no improvement within patience, early stop
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_gan_model(D, G, config['gan_save_path'])
                    if args.fine_tune:
                        save_model(E,
                                   path=config['bert_save_path'],
                                   model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        if args.patience >= args.n_epoch:
            save_gan_model(D, G, config['gan_save_path'])
            if args.fine_tune:
                save_model(E, path=config['bert_save_path'], model_name='bert')

        freeze_data['D_total_fake_loss'] = D_total_fake_loss
        freeze_data['D_total_real_loss'] = D_total_real_loss
        freeze_data['D_total_class_loss'] = D_total_class_loss
        freeze_data['G_total_train_loss'] = G_total_train_loss
        freeze_data['FM_total_train_loss'] = FM_total_train_loss
        freeze_data['valid_real_loss'] = valid_detection_loss
        freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision
        freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall
        freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score

        best_dev = -early_stopping.best_score

        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features
        return result
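The validation logging above uses an ErrorRateAt95Recall helper that is not shown. A hypothetical NumPy implementation of that metric (false-positive rate at the score threshold that reaches 95% recall); the definition in the original codebase may differ:

import numpy as np

def ErrorRateAt95Recall(labels, scores):
    labels = np.asarray(labels).astype(bool)
    scores = np.asarray(scores)
    order = np.argsort(-scores)                      # sort by score, descending
    labels = labels[order]
    tp = np.cumsum(labels)                           # true positives per threshold
    fp = np.cumsum(~labels)                          # false positives per threshold
    idx = np.searchsorted(tp, 0.95 * labels.sum())   # first cut reaching 95% recall
    idx = min(idx, len(labels) - 1)
    return fp[idx] / max((~labels).sum(), 1)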
Example #8
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument('--kshot',
                        type=float,
                        default=5,
                        help="number of examples per class (k-shot)")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--use_mixup",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--beta_sampling_times',
                        type=int,
                        default=10,
                        help="number of beta sampling times (for mixup)")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    train_examples, dev_examples, test_examples, label_list = processor.load_FewRel_data(
        args.kshot)

    num_labels = len(label_list)
    print('num_labels:', num_labels, 'training size:', len(train_examples),
          'dev size:', len(dev_examples), 'test size:', len(test_examples))

    model = RobertaForSequenceClassification(num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=False,  # True for xlnet (cls token at the end)
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  # 2 for xlnet
            sep_token=tokenizer.sep_token,
            # roberta uses an extra separator b/w pairs of sentences,
            # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            sep_token_extra=True,
            pad_on_left=False,  # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)  # 4 for xlnet
        '''load dev set'''
        dev_features = convert_examples_to_features(
            dev_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=False,  # True for xlnet (cls token at the end)
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  # 2 for xlnet
            sep_token=tokenizer.sep_token,
            # roberta uses an extra separator b/w pairs of sentences,
            # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            sep_token_extra=True,
            pad_on_left=False,  # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)  # 4 for xlnet

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                         dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                          dtype=torch.long)
        dev_all_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_span_a_mask = torch.tensor(
            [f.span_a_mask for f in dev_features], dtype=torch.float)
        dev_all_span_b_mask = torch.tensor(
            [f.span_b_mask for f in dev_features], dtype=torch.float)

        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features],
                                         dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask,
                                 dev_all_segment_ids, dev_all_span_a_mask,
                                 dev_all_span_b_mask, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data,
                                    sampler=dev_sampler,
                                    batch_size=args.eval_batch_size)
        '''load test set'''
        test_features = convert_examples_to_features(
            test_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=False,  # True for xlnet (cls token at the end)
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  # 2 for xlnet
            sep_token=tokenizer.sep_token,
            # roberta uses an extra separator b/w pairs of sentences,
            # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            sep_token_extra=True,
            pad_on_left=False,  # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)  # 4 for xlnet

        eval_all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                          dtype=torch.long)
        eval_all_input_mask = torch.tensor(
            [f.input_mask for f in test_features], dtype=torch.long)
        eval_all_segment_ids = torch.tensor(
            [f.segment_ids for f in test_features], dtype=torch.long)
        eval_all_span_a_mask = torch.tensor(
            [f.span_a_mask for f in test_features], dtype=torch.float)
        eval_all_span_b_mask = torch.tensor(
            [f.span_b_mask for f in test_features], dtype=torch.float)
        # eval_all_pair_ids = [f.pair_id for f in test_features]
        eval_all_label_ids = torch.tensor([f.label_id for f in test_features],
                                          dtype=torch.long)

        eval_data = TensorDataset(eval_all_input_ids, eval_all_input_mask,
                                  eval_all_segment_ids, eval_all_span_a_mask,
                                  eval_all_span_b_mask, eval_all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        test_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        # logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_span_a_mask = torch.tensor([f.span_a_mask for f in train_features],
                                       dtype=torch.float)
        all_span_b_mask = torch.tensor([f.span_b_mask for f in train_features],
                                       dtype=torch.float)

        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_span_a_mask,
                                   all_span_b_mask, all_label_ids)
        train_sampler = RandomSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        iter_co = 0
        final_test_performance = 0.0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, span_a_mask, span_b_mask, label_ids = batch

                #input_ids, input_mask, span_a_mask, span_b_mask
                logits = model(input_ids, input_mask, span_a_mask, span_b_mask)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, num_labels),
                                label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                iter_co += 1
                # if iter_co %20==0:
                if iter_co % len(train_dataloader) == 0:
                    # if iter_co % (len(train_dataloader)//2)==0:
                    '''
                    evaluate on the dev set at the end of this epoch
                    '''
                    model.eval()

                    for idd, dev_or_test_dataloader in enumerate(
                        [dev_dataloader, test_dataloader]):

                        if idd == 0:
                            logger.info("***** Running dev *****")
                            logger.info("  Num examples = %d",
                                        len(dev_features))
                        else:
                            logger.info("***** Running test *****")
                            logger.info("  Num examples = %d",
                                        len(test_features))
                        # logger.info("  Batch size = %d", args.eval_batch_size)

                        eval_loss = 0
                        nb_eval_steps = 0
                        preds = []
                        gold_label_ids = []
                        # print('Evaluating...')
                        for input_ids, input_mask, segment_ids, span_a_mask, span_b_mask, label_ids in dev_or_test_dataloader:
                            input_ids = input_ids.to(device)
                            input_mask = input_mask.to(device)
                            segment_ids = segment_ids.to(device)
                            span_a_mask = span_a_mask.to(device)
                            span_b_mask = span_b_mask.to(device)
                            label_ids = label_ids.to(device)
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())

                            with torch.no_grad():
                                logits = model(input_ids, input_mask,
                                               span_a_mask, span_b_mask)
                            if len(preds) == 0:
                                preds.append(logits.detach().cpu().numpy())
                            else:
                                preds[0] = np.append(
                                    preds[0],
                                    logits.detach().cpu().numpy(),
                                    axis=0)

                        preds = preds[0]

                        pred_probs = softmax(preds, axis=1)
                        pred_label_ids = list(np.argmax(pred_probs, axis=1))

                        assert len(pred_label_ids) == len(gold_label_ids)
                        hit_co = 0
                        for k in range(len(pred_label_ids)):
                            if pred_label_ids[k] == gold_label_ids[k]:
                                hit_co += 1
                        test_acc = hit_co / len(gold_label_ids)
                        f1 = test_acc

                        if idd == 0:  # this is dev
                            if f1 > max_dev_acc:
                                max_dev_acc = f1
                                print('\ndev acc :', f1, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                                # '''store the model, because we can test after a max_dev acc reached'''
                                # model_to_save = (
                                #     model.module if hasattr(model, "module") else model
                                # )  # Take care of distributed/parallel training
                                # store_transformers_models(model_to_save, tokenizer, '/export/home/Dataset/BERT_pretrained_mine/event_2_nli', 'mnli_mypretrained_f1_'+str(max_dev_acc)+'.pt')

                            else:
                                print('\ndev acc :', f1, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                                break
                        else:  # this is test
                            if f1 > max_test_acc:
                                max_test_acc = f1
                            final_test_performance = f1
                            print('\ntest acc:', f1, ' max_test_acc:',
                                  max_test_acc, '\n')
        print('final_test_f1:', final_test_performance)
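The manual accuracy loop in the evaluation block above can also be written in vectorized form. An equivalent sketch with illustrative arrays, assuming scipy.special.softmax as in the snippet:

import numpy as np
from scipy.special import softmax

preds = np.random.randn(100, 5)                # stacked logits, shape (N, num_labels)
gold_label_ids = np.random.randint(0, 5, 100)
pred_label_ids = softmax(preds, axis=1).argmax(axis=1)
test_acc = float((pred_label_ids == gold_label_ids).mean())
print('acc:', test_acc)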
Example #9

def main():
    parser = ArgumentParser()
    parser.add_argument('--pregenerated_neg_data', type=Path, required=True)
    parser.add_argument('--pregenerated_data', type=Path, required=True)
    parser.add_argument('--output_dir', type=Path, required=True)
    parser.add_argument(
        "--bert_model",
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )

    parser.add_argument("--max_seq_len", default=512, type=int)

    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--epochs",
                        type=int,
                        default=3,
                        help="Number of epochs to train for")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--kr_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--kr_freq", default=0.7, type=float)
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--learning_rate",
                        default=1e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    assert args.pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"

    samples_per_epoch = []
    for i in range(args.epochs):
        epoch_file = args.pregenerated_data / f"epoch_{i}.json"
        metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit("No training data was found!")
            print(
                f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs})."
            )
            print(
                "This script will loop over the available data, but training diversity may be negatively impacted."
            )
            num_data_epochs = i
            break
    else:
        num_data_epochs = args.epochs

    if args.local_rank == -1 or args.no_cuda:
        print(torch.cuda.is_available())
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
        print(n_gpu)
        print("no gpu?")
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        print("GPU Device: ", device)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    logging.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # if n_gpu > 0:
    torch.cuda.manual_seed_all(args.seed)

    pt_output = Path(getenv('PT_OUTPUT_DIR', ''))
    args.output_dir = Path(os.path.join(pt_output, args.output_dir))

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logging.warning(
            f"Output directory ({args.output_dir}) already exists and is not empty!"
        )
    args.output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    total_train_examples = 0
    for i in range(args.epochs):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = int(total_train_examples /
                                       args.train_batch_size /
                                       args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    # Prepare model
    config = BertConfig.from_pretrained(args.bert_model)
    # config.num_hidden_layers = args.num_layers
    model = FuckWrapper(config)
    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=num_train_optimization_steps)
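    # Note: WarmupLinearSchedule comes from the older pytorch_transformers API.
    # With the current `transformers` package the equivalent call (an assumption,
    # not part of this script) would be:
    #   from transformers import get_linear_schedule_with_warmup
    #   scheduler = get_linear_schedule_with_warmup(
    #       optimizer,
    #       num_warmup_steps=args.warmup_steps,
    #       num_training_steps=num_train_optimization_steps)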
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    global_step = 0
    logging.info("***** Running training *****")
    logging.info(f"  Num examples = {total_train_examples}")
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)
    model.train()

    before_train_path = Path(os.path.join(args.output_dir, "before_training"))
    print("Before training path: ", before_train_path)
    before_train_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(os.path.join(args.output_dir, "before_training"))
    tokenizer.save_pretrained(os.path.join(args.output_dir, "before_training"))

    neg_epoch_dataset = PregeneratedDataset(
        epoch=0,
        training_path=args.pregenerated_neg_data,
        tokenizer=tokenizer,
        num_data_epochs=num_data_epochs,
        reduce_memory=args.reduce_memory)
    if args.local_rank == -1:
        neg_train_sampler = RandomSampler(neg_epoch_dataset)
    else:
        neg_train_sampler = DistributedSampler(neg_epoch_dataset)

    neg_train_dataloader = DataLoader(neg_epoch_dataset,
                                      sampler=neg_train_sampler,
                                      batch_size=args.train_batch_size)

    def inf_train_gen():
        while True:
            for kr_step, kr_batch in enumerate(neg_train_dataloader):
                yield kr_step, kr_batch

    kr_gen = inf_train_gen()
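    # `inf_train_gen` cycles endlessly over the negated-data loader, so the main
    # loop below can interleave a negated (kr) batch after a regular MLM batch
    # with probability 1 - args.kr_freq without ever exhausting the loader.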

    for epoch in range(args.epochs):
        epoch_dataset = PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory)
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)

        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0

        if (n_gpu > 1 and args.local_rank == -1) or n_gpu <= 1:
            logging.info("** ** * Saving fine-tuned model ** ** * ")
            model.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)

        with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
            for step, batch in enumerate(train_dataloader):
                model.train()

                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch

                outputs = model(input_ids=input_ids,
                                attention_mask=input_mask,
                                token_type_ids=segment_ids,
                                masked_lm_labels=lm_label_ids,
                                negated=False)
                loss = outputs[0]
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)

                if args.local_rank == 0 or args.local_rank == -1:
                    nb_tr_steps += 1
                    pbar.update(1)
                    mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                    pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule after the optimizer step
                    optimizer.zero_grad()
                    global_step += 1

                if random.random() > args.kr_freq:
                    kr_step, kr_batch = next(kr_gen)
                    kr_batch = tuple(t.to(device) for t in kr_batch)
                    input_ids, input_mask, segment_ids, lm_label_ids = kr_batch

                    outputs = model(input_ids=input_ids,
                                    attention_mask=input_mask,
                                    token_type_ids=segment_ids,
                                    masked_lm_labels=lm_label_ids,
                                    negated=True)
                    loss = outputs[0]
                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                    tr_loss += loss.item()
                    nb_tr_examples += input_ids.size(0)
                    if args.local_rank == -1:
                        nb_tr_steps += 1
                        mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                        pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        optimizer.step()
                        scheduler.step()  # Update learning rate schedule after the optimizer step
                        optimizer.zero_grad()
                        global_step += 1

    # Save a trained model
    if (n_gpu > 1 and args.local_rank == -1) or n_gpu <= 1:
        logging.info("** ** * Saving fine-tuned model ** ** * ")
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
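
# Note: several examples in this file build the same two parameter groups before
# constructing AdamW (weight decay on most weights, none on biases/LayerNorm).
# A small reusable sketch of that pattern, using torch.optim.AdamW here as a
# stand-in for whichever AdamW implementation the surrounding scripts import:
from torch.optim import AdamW

def build_adamw(model, lr=1e-4, eps=1e-8, weight_decay=0.01):
    """Return AdamW with weight decay disabled for bias and LayerNorm parameters."""
    no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
    decay_params = [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)]
    no_decay_params = [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)]
    return AdamW(
        [{'params': decay_params, 'weight_decay': weight_decay},
         {'params': no_decay_params, 'weight_decay': 0.0}],
        lr=lr, eps=eps)
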
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--DomainName",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_data_aug",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--meta_epochs',
                        type=int,
                        default=10,
                        help="random seed for initialization")
    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()


    processors = {
        "rte": RteProcessor
    }

    output_modes = {
        "rte": "classification"
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")


    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    # label_list = processor.get_labels() #["entailment", "neutral", "contradiction"]
    # label_list = ['How_do_I_create_a_profile_v4', 'Profile_Switch_v4', 'Deactivate_Active_Devices_v4', 'Ads_on_Hulu_v4', 'Watching_Hulu_with_Live_TV_v4', 'Hulu_Costs_and_Commitments_v4', 'offline_downloads_v4', 'womens_world_cup_v5', 'forgot_username_v4', 'confirm_account_cancellation_v4', 'Devices_to_Watch_HBO_on_v4', 'remove_add_on_v4', 'Internet_Speed_for_HD_and_4K_v4', 'roku_related_questions_v4', 'amazon_related_questions_v4', 'Clear_Browser_Cache_v4', 'ads_on_ad_free_plan_v4', 'inappropriate_ads_v4', 'itunes_related_questions_v4', 'Internet_Speed_Recommendations_v4', 'NBA_Basketball_v5', 'unexpected_charges_v4', 'change_billing_date_v4', 'NFL_on_Hulu_v5', 'How_to_delete_a_profile_v4', 'Devices_to_Watch_Hulu_on_v4', 'Manage_your_Hulu_subscription_v4', 'cancel_hulu_account_v4', 'disney_bundle_v4', 'payment_issues_v4', 'home_network_location_v4', 'Main_Menu_v4', 'Resetting_Hulu_Password_v4', 'Update_Payment_v4', 'I_need_general_troubleshooting_help_v4', 'What_is_Hulu_v4', 'sprint_related_questions_v4', 'Log_into_TV_with_activation_code_v4', 'Game_of_Thrones_v4', 'video_playback_issues_v4', 'How_to_edit_a_profile_v4', 'Watchlist_Remove_Video_v4', 'spotify_related_questions_v4', 'Deactivate_Login_Sessions_v4', 'Transfer_to_Agent_v4', 'Use_Hulu_Internationally_v4']

    meta_train_examples, meta_dev_examples, meta_test_examples, meta_label_list = load_CLINC150_without_specific_domain(args.DomainName)
    train_examples, dev_examples, eval_examples, finetune_label_list = load_CLINC150_with_specific_domain_sequence(args.DomainName, args.kshot, augment=args.do_data_aug)
    # oos_dev_examples, oos_test_examples = load_OOS()
    # dev_examples+=oos_dev_examples
    # eval_examples+=oos_test_examples

    eval_label_list = finetune_label_list#+['oos']
    label_list=finetune_label_list+meta_label_list#+['oos']
    assert len(label_list) ==  15*10
    num_labels = len(label_list)
    assert num_labels == 15*10


    model = RobertaForSequenceClassification(num_labels)


    tokenizer = RobertaTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)
    # tokenizer = BertTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)

    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    optimizer = AdamW(optimizer_grouped_parameters,
                             lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        meta_train_features = convert_examples_to_features(
            meta_train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)


        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)

        '''load dev set'''
        # dev_examples = processor.get_RTE_as_dev('/export/home/Dataset/glue_data/RTE/dev.tsv')
        # dev_examples = get_data_hulu('dev')
        dev_features = convert_examples_to_features(
            dev_examples, eval_label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features], dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features], dtype=torch.long)
        dev_all_segment_ids = torch.tensor([f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features], dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask, dev_all_segment_ids, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.eval_batch_size)


        '''load test set'''
        # eval_examples = processor.get_RTE_as_test('/export/home/Dataset/RTE/test_RTE_1235.txt')
        # eval_examples = get_data_hulu('test')
        eval_features = convert_examples_to_features(
            eval_examples, eval_label_list, args.max_seq_length, tokenizer, output_mode,
            cls_token_at_end=False,#bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,#2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=True,#bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=False,#bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0)#4 if args.model_type in ['xlnet'] else 0,)

        eval_all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        eval_all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        eval_all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        eval_all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

        eval_data = TensorDataset(eval_all_input_ids, eval_all_input_mask, eval_all_segment_ids, eval_all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in meta_train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in meta_train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in meta_train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in meta_train_features], dtype=torch.long)

        meta_train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        meta_train_sampler = RandomSampler(meta_train_data)
        meta_train_dataloader = DataLoader(meta_train_data, sampler=meta_train_sampler, batch_size=args.train_batch_size*10)


        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
        '''support labeled examples in order, group in kshot size'''
        support_sampler = SequentialSampler(train_data)
        support_dataloader = DataLoader(train_data, sampler=support_sampler, batch_size=args.kshot)


        iter_co = 0
        max_dev_test = [0,0]
        fine_max_dev = False
        '''first train on meta_train tasks'''
        for meta_epoch_i in trange(args.meta_epochs, desc="metaEpoch"):
            for step, batch in enumerate(tqdm(meta_train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                logits,_,_ = model(input_ids, input_mask, None, labels=None)
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            '''get class representation after each epoch of pretraining'''
            model.eval()
            last_reps_list = []
            for input_ids, input_mask, segment_ids, label_ids in support_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)
                # gold_label_ids+=list(label_ids.detach().cpu().numpy())

                with torch.no_grad():
                    logits, last_reps, bias = model(input_ids, input_mask, None, labels=None)
                last_reps_list.append(last_reps.mean(dim=0, keepdim=True)) #(1, 1024)
            class_reps_pretraining = torch.cat(last_reps_list, dim=0) #(15, 1024)

            '''
            start evaluate on dev set after this epoch
            '''
            for idd, dev_or_test_dataloader in enumerate([dev_dataloader, eval_dataloader]):
                if idd == 0:
                    logger.info("***** Running dev *****")
                    logger.info("  Num examples = %d", len(dev_examples))
                else:
                    logger.info("***** Running test *****")
                    logger.info("  Num examples = %d", len(eval_examples))
                # logger.info("  Batch size = %d", args.eval_batch_size)

                eval_loss = 0
                nb_eval_steps = 0
                preds = []
                gold_label_ids = []
                # print('Evaluating...')
                for input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)
                    gold_label_ids+=list(label_ids.detach().cpu().numpy())

                    with torch.no_grad():
                        logits_LR, reps_batch, _ = model(input_ids, input_mask, None, labels=None)
                    # logits = logits[0]

                    '''pretraining logits'''
                    raw_similarity_scores = torch.mm(reps_batch,torch.transpose(class_reps_pretraining, 0,1)) #(batch, 15)
                    # print('raw_similarity_scores shape:', raw_similarity_scores.shape)
                    # print('bias_finetune:', bias_finetune.shape)
                    biased_similarity_scores = raw_similarity_scores#+bias_finetune.view(-1, raw_similarity_scores.shape[1])
                    logits_pretrain = torch.max(biased_similarity_scores.view(reps_batch.shape[0], -1, len(finetune_label_list)), dim=1)[0]  # (batch, #class); use the actual batch size so the last, smaller batch does not break the view
                    '''finetune logits'''
                    # raw_similarity_scores = torch.mm(reps_batch,torch.transpose(class_reps_finetune, 0,1)) #(batch, 15*history)
                    # biased_similarity_scores = raw_similarity_scores+bias_finetune.view(-1, raw_similarity_scores.shape[1])
                    # logits_finetune = torch.max(biased_similarity_scores.view(args.eval_batch_size, -1, len(finetune_label_list)), dim=1)[0] #(batch, #class)

                    logits = logits_pretrain#+logits_finetune
                    # logits = (1-0.9)*logits+0.9*logits_LR

                    if len(preds) == 0:
                        preds.append(logits.detach().cpu().numpy())
                    else:
                        preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)

                # eval_loss = eval_loss / nb_eval_steps
                preds = preds[0]
                pred_probs = softmax(preds,axis=1)
                pred_label_ids = list(np.argmax(pred_probs, axis=1))
                assert len(pred_label_ids) == len(gold_label_ids)
                hit_co = 0

                for k in range(len(pred_label_ids)):
                    if pred_label_ids[k] == gold_label_ids[k]:
                        hit_co +=1
                test_acc = hit_co/len(gold_label_ids)

                if idd == 0: # this is dev
                    if test_acc > max_dev_acc:
                        max_dev_acc = test_acc
                        print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')
                        fine_max_dev=True
                        max_dev_test[0] = round(max_dev_acc*100, 2)
                    else:
                        print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')
                        break
                else: # this is test
                    if test_acc > max_test_acc:
                        max_test_acc = test_acc
                    if fine_max_dev:
                        max_dev_test[1] = round(test_acc*100,2)
                        fine_max_dev = False
                    print('\ttest acc:', test_acc, ' max_test_acc:', max_test_acc, '\n')


        print('final:', str(max_dev_test[0])+'/'+str(max_dev_test[1]), '\n')
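
# Note: the evaluation above scores each example by dot-product similarity to
# per-class "prototype" vectors averaged over the k-shot support set. A compact,
# self-contained sketch of that idea (names and shapes here are illustrative,
# not the original code):
import torch

def prototype_predict(support_reps, support_labels, query_reps, num_classes):
    """Nearest-prototype prediction with dot-product similarity.

    support_reps: (num_support, hidden), support_labels: (num_support,)
    query_reps:   (num_query, hidden)  ->  (num_query,) predicted class ids
    """
    prototypes = torch.stack(
        [support_reps[support_labels == c].mean(dim=0) for c in range(num_classes)])
    similarity = query_reps @ prototypes.t()  # (num_query, num_classes)
    return similarity.argmax(dim=1)
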
Beispiel #11
0
def main():
    parser = argparse.ArgumentParser()

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="random seed for initialization")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--target_train_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    scitail_path = '/export/home/Dataset/SciTailV1/tsv_format/'
    target_kshot_entail_examples, target_kshot_nonentail_examples = get_SciTail_as_train_k_shot(
        scitail_path + 'scitail_1.0_train.tsv', args.kshot,
        args.seed)  #train_pu_half_v1.txt
    target_dev_examples, target_test_examples = get_SciTail_dev_and_test(
        scitail_path + 'scitail_1.0_dev.tsv',
        scitail_path + 'scitail_1.0_test.tsv')

    system_seed = 42
    random.seed(system_seed)
    np.random.seed(system_seed)
    torch.manual_seed(system_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(system_seed)

    source_kshot_size = 10  # number of MNLI support examples per class (fixed at 10, regardless of args.kshot)
    source_kshot_entail, source_kshot_neural, source_kshot_contra, source_remaining_examples = get_MNLI_train(
        '/export/home/Dataset/glue_data/MNLI/train.tsv', source_kshot_size)
    source_examples = source_kshot_entail + source_kshot_neural + source_kshot_contra + source_remaining_examples
    target_label_list = ["entails", "neutral"]
    source_label_list = ["entailment", "neutral", "contradiction"]
    source_num_labels = len(source_label_list)
    target_num_labels = len(target_label_list)
    print('training size:', len(source_examples), 'dev size:',
          len(target_dev_examples), 'test size:', len(target_test_examples))

    roberta_model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    roberta_model.load_state_dict(torch.load(
        '/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'
    ),
                                  strict=False)
    roberta_model.to(device)
    roberta_model.eval()

    protonet = PrototypeNet(bert_hidden_dim)
    protonet.to(device)

    param_optimizer = list(protonet.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    retrieve_batch_size = 5

    source_kshot_entail_dataloader = examples_to_features(
        source_kshot_entail,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_neural_dataloader = examples_to_features(
        source_kshot_neural,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_contra_dataloader = examples_to_features(
        source_kshot_contra,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_remain_ex_dataloader = examples_to_features(
        source_remaining_examples,
        source_label_list,
        args,
        tokenizer,
        args.train_batch_size,
        "classification",
        dataloader_mode='random')

    target_kshot_entail_dataloader = examples_to_features(
        target_kshot_entail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_kshot_nonentail_dataloader = examples_to_features(
        target_kshot_nonentail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_dev_dataloader = examples_to_features(target_dev_examples,
                                                 target_label_list,
                                                 args,
                                                 tokenizer,
                                                 args.eval_batch_size,
                                                 "classification",
                                                 dataloader_mode='sequential')
    target_test_dataloader = examples_to_features(target_test_examples,
                                                  target_label_list,
                                                  args,
                                                  tokenizer,
                                                  args.eval_batch_size,
                                                  "classification",
                                                  dataloader_mode='sequential')
    '''
    retrieve rep for support examples in MNLI
    '''
    kshot_entail_reps = []
    for entail_batch in source_kshot_entail_dataloader:
        entail_batch = tuple(t.to(device) for t in entail_batch)
        input_ids, input_mask, segment_ids, label_ids = entail_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_entail, _ = roberta_model(input_ids, input_mask)
        kshot_entail_reps.append(last_hidden_entail)
    kshot_entail_rep = torch.mean(torch.cat(kshot_entail_reps, dim=0),
                                  dim=0,
                                  keepdim=True)
    kshot_neural_reps = []
    for neural_batch in source_kshot_neural_dataloader:
        neural_batch = tuple(t.to(device) for t in neural_batch)
        input_ids, input_mask, segment_ids, label_ids = neural_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_neural, _ = roberta_model(input_ids, input_mask)
        kshot_neural_reps.append(last_hidden_neural)
    kshot_neural_rep = torch.mean(torch.cat(kshot_neural_reps, dim=0),
                                  dim=0,
                                  keepdim=True)
    kshot_contra_reps = []
    for contra_batch in source_kshot_contra_dataloader:
        contra_batch = tuple(t.to(device) for t in contra_batch)
        input_ids, input_mask, segment_ids, label_ids = contra_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_contra, _ = roberta_model(input_ids, input_mask)
        kshot_contra_reps.append(last_hidden_contra)
    kshot_contra_rep = torch.mean(torch.cat(kshot_contra_reps, dim=0),
                                  dim=0,
                                  keepdim=True)

    source_class_prototype_reps = torch.cat(
        [kshot_entail_rep, kshot_neural_rep, kshot_contra_rep],
        dim=0)  #(3, hidden)
    '''first get representations for support examples in target'''
    kshot_entail_reps = []
    for entail_batch in target_kshot_entail_dataloader:
        entail_batch = tuple(t.to(device) for t in entail_batch)
        input_ids, input_mask, segment_ids, label_ids = entail_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_entail, _ = roberta_model(input_ids, input_mask)
        kshot_entail_reps.append(last_hidden_entail)
    all_kshot_entail_reps = torch.cat(kshot_entail_reps, dim=0)
    kshot_entail_rep = torch.mean(all_kshot_entail_reps, dim=0, keepdim=True)
    kshot_nonentail_reps = []
    for nonentail_batch in target_kshot_nonentail_dataloader:
        nonentail_batch = tuple(t.to(device) for t in nonentail_batch)
        input_ids, input_mask, segment_ids, label_ids = nonentail_batch
        roberta_model.eval()
        with torch.no_grad():
            last_hidden_nonentail, _ = roberta_model(input_ids, input_mask)
        kshot_nonentail_reps.append(last_hidden_nonentail)
    all_kshot_neural_reps = torch.cat(kshot_nonentail_reps, dim=0)
    kshot_nonentail_rep = torch.mean(all_kshot_neural_reps,
                                     dim=0,
                                     keepdim=True)
    target_class_prototype_reps = torch.cat(
        [kshot_entail_rep, kshot_nonentail_rep, kshot_nonentail_rep],
        dim=0)  #(3, hidden)

    class_prototype_reps = torch.cat(
        [source_class_prototype_reps, target_class_prototype_reps],
        dim=0)  #(6, hidden)
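    # The 6 prototypes are the 3 MNLI class means followed by the 3 target-side
    # means, where the target non-entail mean is duplicated so it fills both the
    # "neutral" and "contradiction" slots of the 3-way source label space.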
    '''starting to train'''
    iter_co = 0
    tr_loss = 0
    source_loss = 0
    target_loss = 0
    final_test_performance = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):

        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm(source_remain_ex_dataloader, desc="Iteration")):
            protonet.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, source_label_ids_batch = batch

            roberta_model.eval()
            with torch.no_grad():
                source_last_hidden_batch, _ = roberta_model(
                    input_ids, input_mask)
            '''forward to model'''
            target_batch_size = args.target_train_batch_size  #10*3
            target_batch_size_entail = target_batch_size  #random.randrange(5)+1
            target_batch_size_neural = target_batch_size  #random.randrange(5)+1

            selected_target_entail_rep = all_kshot_entail_reps[torch.randperm(
                all_kshot_entail_reps.shape[0])[:target_batch_size_entail]]
            selected_target_neural_rep = all_kshot_neural_reps[torch.randperm(
                all_kshot_neural_reps.shape[0])[:target_batch_size_neural]]
            target_last_hidden_batch = torch.cat(
                [selected_target_entail_rep, selected_target_neural_rep])

            last_hidden_batch = torch.cat(
                [source_last_hidden_batch, target_last_hidden_batch],
                dim=0)  #(train_batch_size+10*2)
            batch_logits = protonet(class_prototype_reps, last_hidden_batch)
            '''source side loss'''
            # loss_fct = CrossEntropyLoss(reduction='none')
            loss_fct = CrossEntropyLoss()
            source_loss_list = loss_fct(
                batch_logits[:source_last_hidden_batch.shape[0]].view(
                    -1, source_num_labels), source_label_ids_batch.view(-1))
            '''target side loss'''
            target_label_ids_batch = torch.tensor(
                [0] * selected_target_entail_rep.shape[0] +
                [1] * selected_target_neural_rep.shape[0],
                dtype=torch.long)
            target_batch_logits = batch_logits[-target_last_hidden_batch.
                                               shape[0]:]
            target_loss_list = loss_by_logits_and_2way_labels(
                target_batch_logits, target_label_ids_batch.view(-1), device)

            loss = source_loss_list + target_loss_list  #torch.mean(torch.cat([source_loss_list, target_loss_list]))
            source_loss += source_loss_list.item()  # .item() avoids retaining the autograd graph across iterations
            target_loss += target_loss_list.item()
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            global_step += 1
            iter_co += 1
            '''print loss'''
            # if iter_co %5==0:
            #     print('iter_co:', iter_co, ' mean loss', tr_loss/iter_co)
            #     print('source_loss_list:', source_loss/iter_co, ' target_loss_list: ', target_loss/iter_co)
            if iter_co % 1 == 0:
                # if iter_co % len(source_remain_ex_dataloader)==0:
                '''
                start evaluate on dev set after this epoch
                '''
                protonet.eval()

                for idd, dev_or_test_dataloader in enumerate(
                    [target_dev_dataloader, target_test_dataloader]):

                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    # print('Evaluating...')
                    for input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        label_ids = label_ids.to(device)
                        gold_label_ids += list(
                            label_ids.detach().cpu().numpy())
                        roberta_model.eval()
                        with torch.no_grad():
                            last_hidden_target_batch, logits_from_source = roberta_model(
                                input_ids, input_mask)

                        with torch.no_grad():
                            logits = protonet(class_prototype_reps,
                                              last_hidden_target_batch)
                        '''combine with logits from source domain'''
                        # print('logits:', logits)
                        # print('logits_from_source:', logits_from_source)
                        # weight = 0.9
                        # logits = weight*logits+(1.0-weight)*torch.sigmoid(logits_from_source)
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0],
                                                 logits.detach().cpu().numpy(),
                                                 axis=0)

                    preds = preds[0]

                    pred_probs = softmax(preds, axis=1)
                    pred_label_ids_3way = list(np.argmax(pred_probs, axis=1))
                    '''change from 3-way to 2-way'''
                    pred_label_ids = []
                    for pred_id in pred_label_ids_3way:
                        if pred_id != 0:
                            pred_label_ids.append(1)
                        else:
                            pred_label_ids.append(0)

                    assert len(pred_label_ids) == len(gold_label_ids)
                    hit_co = 0
                    for k in range(len(pred_label_ids)):
                        if pred_label_ids[k] == gold_label_ids[k]:
                            hit_co += 1
                    test_acc = hit_co / len(gold_label_ids)

                    if idd == 0:  # this is dev
                        if test_acc > max_dev_acc:
                            max_dev_acc = test_acc
                            print('\niter', iter_co, '\tdev acc:', test_acc,
                                  ' max_dev_acc:', max_dev_acc, '\n')

                        else:
                            print('\niter', iter_co, '\tdev acc:', test_acc,
                                  ' max_dev_acc:', max_dev_acc, '\n')
                            break
                    else:  # this is test
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\niter', iter_co, '\ttest acc:', test_acc,
                              ' max_test_acc:', max_test_acc, '\n')
            # if iter_co == 500:#3000:
            #     break
    print('final_test_performance:', final_test_performance)
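
# Note: the 3-way-to-2-way conversion above maps MNLI predictions
# (0=entailment, 1=neutral, 2=contradiction) onto SciTail's labels by treating
# everything that is not "entailment" as non-entail. A one-line numpy sketch of
# the same mapping (function name is illustrative):
import numpy as np

def collapse_to_two_way(pred_probs):
    """pred_probs: (n, 3) softmax scores -> list of ids with 0=entails, 1=non-entail."""
    return (np.argmax(pred_probs, axis=1) != 0).astype(int).tolist()
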
Beispiel #12
0
def main():
    parser = argparse.ArgumentParser()


    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="number of labeled support examples per class (k-shot)")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")


    args = parser.parse_args()



    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


    mctest_path = '/export/home/Dataset/MCTest/Statements/'
    target_kshot_entail_examples, target_kshot_nonentail_examples = get_MCTest_train(mctest_path+'mc500.train.statements.pairs', args.kshot) #train_pu_half_v1.txt
    target_dev_examples, target_test_examples = get_MCTest_dev_and_test(mctest_path+'mc500.dev.statements.pairs', mctest_path+'mc500.test.statements.pairs')


    source_kshot_entail, source_kshot_neural, source_kshot_contra, source_remaining_examples = get_MNLI_train('/export/home/Dataset/glue_data/MNLI/train.tsv', args.kshot)
    source_examples = source_kshot_entail + source_kshot_neural + source_kshot_contra + source_remaining_examples
    target_label_list = ["ENTAILMENT", "UNKNOWN"]
    source_label_list = ["entailment", "neutral", "contradiction"]
    source_num_labels = len(source_label_list)
    target_num_labels = len(target_label_list)
    print('training size:', len(source_examples), 'dev size:', len(target_dev_examples), 'test size:', len(target_test_examples))

    num_train_optimization_steps = int(
        len(source_remaining_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    roberta_model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)
    roberta_model.load_state_dict(torch.load('/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'),strict=False)
    roberta_model.to(device)
    roberta_model.eval()

    protonet = PrototypeNet(bert_hidden_dim)
    protonet.to(device)

    param_optimizer = list(protonet.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    retrieve_batch_size = 5

    source_kshot_entail_dataloader = examples_to_features(source_kshot_entail, source_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    source_kshot_neural_dataloader = examples_to_features(source_kshot_neural, source_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    source_kshot_contra_dataloader = examples_to_features(source_kshot_contra, source_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    source_remain_ex_dataloader = examples_to_features(source_remaining_examples, source_label_list, args, tokenizer, args.train_batch_size, "classification", dataloader_mode='random')

    target_kshot_entail_dataloader = examples_to_features(target_kshot_entail_examples, target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    target_kshot_nonentail_dataloader = examples_to_features(target_kshot_nonentail_examples, target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
    target_dev_dataloader = examples_to_features(target_dev_examples, target_label_list, args, tokenizer, args.eval_batch_size, "classification", dataloader_mode='sequential')
    target_test_dataloader = examples_to_features(target_test_examples, target_label_list, args, tokenizer, args.eval_batch_size, "classification", dataloader_mode='sequential')

    '''starting to train'''
    iter_co = 0
    final_test_performance = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(source_remain_ex_dataloader, desc="Iteration")):
            protonet.train()
            batch = tuple(t.to(device) for t in batch)
            _, input_ids, input_mask, segment_ids, label_ids_batch = batch

            roberta_model.eval()
            with torch.no_grad():
                last_hidden_batch, _ = roberta_model(input_ids, input_mask)
            '''
            retrieve rep for support examples
            '''
            kshot_entail_reps = []
            for entail_batch in source_kshot_entail_dataloader:
                entail_batch = tuple(t.to(device) for t in entail_batch)
                _, input_ids, input_mask, segment_ids, label_ids = entail_batch
                roberta_model.eval()
                with torch.no_grad():
                    last_hidden_entail, _ = roberta_model(input_ids, input_mask)
                kshot_entail_reps.append(last_hidden_entail)
            kshot_entail_rep = torch.mean(torch.cat(kshot_entail_reps, dim=0), dim=0, keepdim=True)
            kshot_neural_reps = []
            for neural_batch in source_kshot_neural_dataloader:
                neural_batch = tuple(t.to(device) for t in neural_batch)
                _, input_ids, input_mask, segment_ids, label_ids = neural_batch
                roberta_model.eval()
                with torch.no_grad():
                    last_hidden_neural, _ = roberta_model(input_ids, input_mask)
                kshot_neural_reps.append(last_hidden_neural)
            kshot_neural_rep = torch.mean(torch.cat(kshot_neural_reps, dim=0), dim=0, keepdim=True)
            kshot_contra_reps = []
            for contra_batch in source_kshot_contra_dataloader:
                contra_batch = tuple(t.to(device) for t in contra_batch)
                _, input_ids, input_mask, segment_ids, label_ids = contra_batch
                roberta_model.eval()
                with torch.no_grad():
                    last_hidden_contra, _ = roberta_model(input_ids, input_mask)
                kshot_contra_reps.append(last_hidden_contra)
            kshot_contra_rep = torch.mean(torch.cat(kshot_contra_reps, dim=0), dim=0, keepdim=True)

            class_prototype_reps = torch.cat([kshot_entail_rep, kshot_neural_rep, kshot_contra_rep], dim=0) #(3, hidden)
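            # Each class prototype is the mean RoBERTa representation of its k-shot
            # support examples; the encoder stays frozen (eval() + no_grad()), so only
            # the ProtoNet head below receives gradient updates.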

            '''forward to model'''
            batch_logits = protonet(class_prototype_reps, last_hidden_batch)

            loss_fct = CrossEntropyLoss()

            loss = loss_fct(batch_logits.view(-1, source_num_labels), label_ids_batch.view(-1))

            if n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            iter_co += 1
            # if iter_co %20==0:
            if iter_co % len(source_remain_ex_dataloader)==0:
                '''
                start evaluate on dev set after this epoch
                '''
                protonet.eval()
                '''first get representations for support examples'''
                kshot_entail_reps = []
                for entail_batch in target_kshot_entail_dataloader:
                    entail_batch = tuple(t.to(device) for t in entail_batch)
                    _, input_ids, input_mask, segment_ids, label_ids = entail_batch
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_entail, _ = roberta_model(input_ids, input_mask)
                    kshot_entail_reps.append(last_hidden_entail)
                kshot_entail_rep = torch.mean(torch.cat(kshot_entail_reps, dim=0), dim=0, keepdim=True)
                kshot_nonentail_reps = []
                for nonentail_batch in target_kshot_nonentail_dataloader:
                    nonentail_batch = tuple(t.to(device) for t in nonentail_batch)
                    _, input_ids, input_mask, segment_ids, label_ids = nonentail_batch
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_nonentail, _ = roberta_model(input_ids, input_mask)
                    kshot_nonentail_reps.append(last_hidden_nonentail)
                kshot_nonentail_rep = torch.mean(torch.cat(kshot_nonentail_reps, dim=0), dim=0, keepdim=True)
                target_class_prototype_reps = torch.cat([kshot_entail_rep, kshot_nonentail_rep], dim=0) #(2, hidden)
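                # At evaluation time the prototypes come from the target domain instead:
                # one entailment and one non-entailment prototype, each averaged over the
                # k-shot MCTest support set, so the same ProtoNet now scores two classes.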

                for idd, dev_or_test_dataloader in enumerate([target_dev_dataloader, target_test_dataloader]):

                    if idd == 0:
                        logger.info("***** Running dev *****")
                        logger.info("  Num examples = %d", len(target_dev_examples))
                    else:
                        logger.info("***** Running test *****")
                        logger.info("  Num examples = %d", len(target_test_examples))


                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    gold_pair_ids = []
                    for input_pair_ids, input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        gold_pair_ids += list(input_pair_ids.numpy())
                        label_ids = label_ids.to(device)
                        gold_label_ids += list(label_ids.detach().cpu().numpy())
                        roberta_model.eval()
                        with torch.no_grad():
                            last_hidden_target_batch, _ = roberta_model(input_ids, input_mask)

                        with torch.no_grad():
                            logits = protonet(target_class_prototype_reps, last_hidden_target_batch)
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)
                    preds = preds[0]
                    pred_probs = list(softmax(preds, axis=1)[:, 0])  # entailment probability

                    assert len(gold_pair_ids) == len(pred_probs)
                    assert len(gold_pair_ids) == len(gold_label_ids)

                    pairID_2_predgoldlist = {}
                    for pair_id, prob, gold_id in zip(gold_pair_ids, pred_probs, gold_label_ids):
                        predgoldlist = pairID_2_predgoldlist.get(pair_id)
                        if predgoldlist is None:
                            predgoldlist = []
                        predgoldlist.append((prob, gold_id))
                        pairID_2_predgoldlist[pair_id] = predgoldlist
                    total_size = len(pairID_2_predgoldlist)
                    hit_size = 0
                    for pair_id, predgoldlist in pairID_2_predgoldlist.items():
                        predgoldlist.sort(key=lambda x: x[0])  # sort by entailment probability (ascending)
                        assert len(predgoldlist) == 4
                        if predgoldlist[-1][1] == 0:
                            hit_size += 1
                    test_acc = hit_size / total_size
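                    # Each question contributes four candidate statements (hence the assert);
                    # a hit is counted when the candidate with the highest entailment
                    # probability is the gold statement (gold label id 0).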

                    if idd == 0: # this is dev
                        if test_acc > max_dev_acc:
                            max_dev_acc = test_acc
                            print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')

                        else:
                            print('\ndev acc:', test_acc, ' max_dev_acc:', max_dev_acc, '\n')
                            break
                    else: # this is test
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\n\t\t test acc:', test_acc, ' max_test_acc:', max_test_acc, '\n')

    print('final_test_performance:', final_test_performance)
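
The examples in this collection keep returning to the same optimizer recipe: AdamW over two parameter groups, with weight decay disabled for biases and LayerNorm weights, optionally paired with a linear warmup schedule. Below is a minimal, self-contained sketch of that recipe; the toy linear model, the learning rate, and the 10% warmup proportion are placeholder assumptions, not values taken from any particular example above.

import torch
from transformers import AdamW, get_linear_schedule_with_warmup

model = torch.nn.Linear(8, 2)   # stand-in for the full RoBERTa / ProtoNet modules above
num_training_steps = 1000       # assumed total number of optimizer steps

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
grouped_parameters = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = AdamW(grouped_parameters, lr=1e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps)
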
Beispiel #13
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default='bert-base-uncased',
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        '--task',
        type=str,
        default=None,
        required=True,
        help="Task code in {hotpot_open, hotpot_distractor, squad, nq}")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=378,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=5,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam. (def: 5e-5)")
    parser.add_argument("--num_train_epochs",
                        default=5.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument('--local_rank', default=-1, type=int)

    # RNN graph retriever-specific parameters
    parser.add_argument("--example_limit", default=None, type=int)

    parser.add_argument("--max_para_num", default=10, type=int)
    parser.add_argument(
        "--neg_chunk",
        default=8,
        type=int,
        help="The chunk size of negative examples during training (to "
        "reduce GPU memory consumption with negative sampling)")
    parser.add_argument(
        "--eval_chunk",
        default=100000,
        type=int,
        help=
        "The chunk size of evaluation examples (to reduce RAM consumption during evaluation)"
    )
    parser.add_argument(
        "--split_chunk",
        default=300,
        type=int,
        help=
        "The chunk size of BERT encoding during inference (to reduce GPU memory consumption)"
    )

    parser.add_argument('--train_file_path',
                        type=str,
                        default=None,
                        help="File path to the training data")
    parser.add_argument('--dev_file_path',
                        type=str,
                        default=None,
                        help="File path to the eval data")

    parser.add_argument('--beam', type=int, default=1, help="Beam size")
    parser.add_argument('--min_select_num',
                        type=int,
                        default=1,
                        help="Minimum number of selected paragraphs")
    parser.add_argument('--max_select_num',
                        type=int,
                        default=3,
                        help="Maximum number of selected paragraphs")
    parser.add_argument(
        "--use_redundant",
        action='store_true',
        help="Whether to use simulated seqs (only for training)")
    parser.add_argument(
        "--use_multiple_redundant",
        action='store_true',
        help="Whether to use multiple simulated seqs (only for training)")
    parser.add_argument(
        '--max_redundant_num',
        type=int,
        default=100000,
        help=
        "Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)"
    )
    parser.add_argument(
        "--no_links",
        action='store_true',
        help=
        "Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)"
    )
    parser.add_argument("--pruning_by_links",
                        action='store_true',
                        help="Whether to do pruning by links (and top 1)")
    parser.add_argument(
        "--expand_links",
        action='store_true',
        help=
        "Whether to expand links with paragraphs in the same article (for NQ)")
    parser.add_argument(
        '--tfidf_limit',
        type=int,
        default=None,
        help=
        "Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)"
    )

    parser.add_argument("--pred_file",
                        default=None,
                        type=str,
                        help="File name to write paragraph selection results")
    parser.add_argument("--tagme",
                        action='store_true',
                        help="Whether to use tagme at inference")
    parser.add_argument(
        '--topk',
        type=int,
        default=2,
        help="Whether to use how many paragraphs from the previous steps")

    parser.add_argument(
        "--model_suffix",
        default=None,
        type=str,
        help="Suffix to load a model file ('pytorch_model_' + suffix +'.bin')")

    parser.add_argument("--db_save_path",
                        default=None,
                        type=str,
                        help="File path to DB")
    parser.add_argument("--fp16", default=False, action='store_true')
    parser.add_argument("--fp16_opt_level", default="O1", type=str)
    parser.add_argument("--do_label",
                        default=False,
                        action='store_true',
                        help="For pre-processing features only.")

    parser.add_argument("--oss_cache_dir", default=None, type=str)
    parser.add_argument("--cache_dir", default=None, type=str)
    parser.add_argument("--dist",
                        default=False,
                        action='store_true',
                        help='use distributed training.')
    parser.add_argument("--save_steps", default=5000, type=int)
    parser.add_argument("--resume", default=None, type=int)
    parser.add_argument("--oss_pretrain", default=None, type=str)
    parser.add_argument("--model_version", default='v1', type=str)
    parser.add_argument("--disable_rnn_layer_norm",
                        default=False,
                        action='store_true')

    args = parser.parse_args()

    if args.dist:
        dist.init_process_group(backend='nccl')
        print(f"local rank: {args.local_rank}")
        print(f"global rank: {dist.get_rank()}")
        print(f"world size: {dist.get_world_size()}")

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')

    if args.dist:
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()
        if world_size > 1:
            args.local_rank = global_rank

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.train_file_path is not None:
        do_train = True

        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.output_dir))
        if args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)

    elif args.dev_file_path is not None:
        do_train = False

    else:
        raise ValueError(
            'One of train_file_path: {} or dev_file_path: {} must be non-None'.
            format(args.train_file_path, args.dev_file_path))

    processor = DataProcessor()

    # Configurations of the graph retriever
    graph_retriever_config = GraphRetrieverConfig(
        example_limit=args.example_limit,
        task=args.task,
        max_seq_length=args.max_seq_length,
        max_select_num=args.max_select_num,
        max_para_num=args.max_para_num,
        tfidf_limit=args.tfidf_limit,
        train_file_path=args.train_file_path,
        use_redundant=args.use_redundant,
        use_multiple_redundant=args.use_multiple_redundant,
        max_redundant_num=args.max_redundant_num,
        dev_file_path=args.dev_file_path,
        beam=args.beam,
        min_select_num=args.min_select_num,
        no_links=args.no_links,
        pruning_by_links=args.pruning_by_links,
        expand_links=args.expand_links,
        eval_chunk=args.eval_chunk,
        tagme=args.tagme,
        topk=args.topk,
        db_save_path=args.db_save_path,
        disable_rnn_layer_norm=args.disable_rnn_layer_norm)

    logger.info(graph_retriever_config)
    logger.info(args)

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)

    if args.model_version == 'roberta':
        from modeling_graph_retriever_roberta import RobertaForGraphRetriever
    elif args.model_version == 'v3':
        from modeling_graph_retriever_roberta import RobertaForGraphRetrieverIterV3 as RobertaForGraphRetriever
    else:
        raise RuntimeError()

    ##############################
    # Training                   #
    ##############################
    if do_train:
        _model_state_dict = None
        if args.oss_pretrain is not None:
            _model_state_dict = torch.load(load_pretrain_from_oss(
                args.oss_pretrain),
                                           map_location='cpu')
            logger.info(f"Loaded pretrained model from {args.oss_pretrain}")

        if args.resume is not None:
            _model_state_dict = torch.load(load_buffer_from_oss(
                os.path.join(args.oss_cache_dir,
                             f"pytorch_model_{args.resume}.bin")),
                                           map_location='cpu')

        model = RobertaForGraphRetriever.from_pretrained(
            args.bert_model,
            graph_retriever_config=graph_retriever_config,
            state_dict=_model_state_dict)

        model.to(device)

        global_step = 0

        POSITIVE = 1.0
        NEGATIVE = 0.0

        _cache_file_name = f"cache_roberta_train_{args.max_seq_length}_{args.max_para_num}"
        _examples_cache_file_name = f"examples_{_cache_file_name}"
        _features_cache_file_name = f"features_{_cache_file_name}"

        # Load training examples
        logger.info(f"Loading training examples and features.")
        try:
            if args.cache_dir is not None and os.path.exists(
                    os.path.join(args.cache_dir, _features_cache_file_name)):
                logger.info(
                    f"Loading pre-processed features from {os.path.join(args.cache_dir, _features_cache_file_name)}"
                )
                train_features = torch.load(
                    os.path.join(args.cache_dir, _features_cache_file_name))
            else:
                # train_examples = torch.load(load_buffer_from_oss(os.path.join(oss_features_cache_dir,
                #                                                               _examples_cache_file_name)))
                train_features = torch.load(
                    load_buffer_from_oss(
                        os.path.join(oss_features_cache_dir,
                                     _features_cache_file_name)))
                logger.info(
                    f"Pre-processed features are loaded from oss: "
                    f"{os.path.join(oss_features_cache_dir, _features_cache_file_name)}"
                )
        except Exception:  # fall back to rebuilding the features if the cache cannot be loaded
            train_examples = processor.get_train_examples(
                graph_retriever_config)
            train_features = convert_examples_to_features(
                train_examples,
                args.max_seq_length,
                args.max_para_num,
                graph_retriever_config,
                tokenizer,
                train=True)
            logger.info(
                f"Saving pre-processed features into oss: {oss_features_cache_dir}"
            )
            torch_save_to_oss(
                train_examples,
                os.path.join(oss_features_cache_dir,
                             _examples_cache_file_name))
            torch_save_to_oss(
                train_features,
                os.path.join(oss_features_cache_dir,
                             _features_cache_file_name))

        if args.do_label:
            logger.info("Finished.")
            return

        # len(train_examples) and len(train_features) can be different, depending on the redundant setting
        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.01,
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]
        t_total = num_train_steps
        if args.local_rank != -1:
            t_total = t_total // dist.get_world_size()

        optimizer = AdamW(optimizer_grouped_parameters,
                          betas=(0.9, 0.98),
                          lr=args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(t_total * args.warmup_proportion), t_total)

        logger.info(optimizer)
        if args.fp16:
            from apex import amp
            amp.register_half_function(torch, "einsum")

            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.fp16_opt_level)

        if args.local_rank != -1:
            if args.fp16_opt_level == 'O2':
                try:
                    import apex
                    model = apex.parallel.DistributedDataParallel(
                        model, delay_allreduce=True)
                except ImportError:
                    model = torch.nn.parallel.DistributedDataParallel(
                        model, find_unused_parameters=True)
            else:
                model = torch.nn.parallel.DistributedDataParallel(
                    model, find_unused_parameters=True)

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if args.resume is not None:
            _amp_state_dict = os.path.join(args.oss_cache_dir,
                                           f"amp_{args.resume}.bin")
            _optimizer_state_dict = os.path.join(
                args.oss_cache_dir, f"optimizer_{args.resume}.pt")
            _scheduler_state_dict = os.path.join(
                args.oss_cache_dir, f"scheduler_{args.resume}.pt")

            amp.load_state_dict(
                torch.load(load_buffer_from_oss(_amp_state_dict)))
            optimizer.load_state_dict(
                torch.load(load_buffer_from_oss(_optimizer_state_dict)))
            scheduler.load_state_dict(
                torch.load(load_buffer_from_oss(_scheduler_state_dict)))

            logger.info(f"Loaded resumed state dict of step {args.resume}")

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Instantaneous batch size per GPU = %d",
                    args.train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.train_batch_size * args.gradient_accumulation_steps *
            (dist.get_world_size() if args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        model.train()
        epc = 0
        # save the initial optimizer/scheduler (and amp) state at step 0
        if args.local_rank in [-1, 0]:
            if args.fp16:
                amp_file = os.path.join(args.oss_cache_dir,
                                        f"amp_{global_step}.bin")
                torch_save_to_oss(amp.state_dict(), amp_file)
            optimizer_file = os.path.join(args.oss_cache_dir,
                                          f"optimizer_{global_step}.pt")
            torch_save_to_oss(optimizer.state_dict(), optimizer_file)
            scheduler_file = os.path.join(args.oss_cache_dir,
                                          f"scheduler_{global_step}.pt")
            torch_save_to_oss(scheduler.state_dict(), scheduler_file)

        tr_loss = 0
        for _ in range(int(args.num_train_epochs)):
            logger.info('Epoch ' + str(epc + 1))

            TOTAL_NUM = len(train_features)
            train_start_index = 0
            CHUNK_NUM = 8
            train_chunk = TOTAL_NUM // CHUNK_NUM
            chunk_index = 0

            random.shuffle(train_features)
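            # The shuffled features are consumed in CHUNK_NUM slices, each with its own
            # TensorDataset/DataLoader that is deleted afterwards to keep host memory bounded.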

            save_retry = False
            while train_start_index < TOTAL_NUM:
                train_end_index = min(train_start_index + train_chunk - 1,
                                      TOTAL_NUM - 1)
                chunk_len = train_end_index - train_start_index + 1

                if args.resume is not None and global_step < args.resume:
                    _chunk_steps = int(
                        math.ceil(chunk_len * 1.0 / args.train_batch_size /
                                  (1 if args.local_rank == -1 else
                                   dist.get_world_size())))
                    _chunk_steps = _chunk_steps // args.gradient_accumulation_steps
                    if global_step + _chunk_steps <= args.resume:
                        global_step += _chunk_steps
                        train_start_index = train_end_index + 1
                        continue

                train_features_ = train_features[
                    train_start_index:train_start_index + chunk_len]

                all_input_ids = torch.tensor(
                    [f.input_ids for f in train_features_], dtype=torch.long)
                all_input_masks = torch.tensor(
                    [f.input_masks for f in train_features_], dtype=torch.long)
                all_segment_ids = torch.tensor(
                    [f.segment_ids for f in train_features_], dtype=torch.long)
                all_output_masks = torch.tensor(
                    [f.output_masks for f in train_features_],
                    dtype=torch.float)
                all_num_paragraphs = torch.tensor(
                    [f.num_paragraphs for f in train_features_],
                    dtype=torch.long)
                all_num_steps = torch.tensor(
                    [f.num_steps for f in train_features_], dtype=torch.long)
                train_data = TensorDataset(all_input_ids, all_input_masks,
                                           all_segment_ids, all_output_masks,
                                           all_num_paragraphs, all_num_steps)

                if args.local_rank != -1:
                    train_sampler = torch.utils.data.DistributedSampler(
                        train_data)
                else:
                    train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(train_data,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size,
                                              pin_memory=True,
                                              num_workers=4)

                if args.local_rank != -1:
                    train_dataloader.sampler.set_epoch(epc)

                logger.info('Examples from ' + str(train_start_index) +
                            ' to ' + str(train_end_index))
                for step, batch in enumerate(
                        tqdm(train_dataloader,
                             desc="Iteration",
                             disable=args.local_rank not in [-1, 0])):
                    if args.resume is not None and global_step < args.resume:
                        if (step + 1) % args.gradient_accumulation_steps == 0:
                            global_step += 1
                        continue

                    input_masks = batch[1]
                    batch_max_len = input_masks.sum(dim=2).max().item()

                    num_paragraphs = batch[4]
                    batch_max_para_num = num_paragraphs.max().item()

                    num_steps = batch[5]
                    batch_max_steps = num_steps.max().item()

                    # output_masks_cpu = (batch[3])[:, :batch_max_steps, :batch_max_para_num + 1]

                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_masks, segment_ids, output_masks, _, _ = batch
                    B = input_ids.size(0)

                    input_ids = input_ids[:, :batch_max_para_num, :
                                          batch_max_len]
                    input_masks = input_masks[:, :batch_max_para_num, :
                                              batch_max_len]
                    segment_ids = segment_ids[:, :batch_max_para_num, :
                                              batch_max_len]
                    output_masks = output_masks[:, :batch_max_steps, :
                                                batch_max_para_num +
                                                1]  # 1 for EOE

                    target = torch.zeros(output_masks.size()).fill_(
                        NEGATIVE)  # (B, NUM_STEPS, |P|+1) <- 1 for EOE
                    for i in range(B):
                        output_masks[i, :num_steps[i], -1] = 1.0  # for EOE

                        for j in range(num_steps[i].item() - 1):
                            target[i, j, j].fill_(POSITIVE)

                        target[i, num_steps[i] - 1, -1].fill_(POSITIVE)
                    target = target.to(device)
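                    # The gold paragraphs occupy the leading positions, so step j selects
                    # paragraph j (target[i, j, j] = POSITIVE) and the final step selects
                    # the extra EOE column appended after the paragraphs.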

                    neg_start = batch_max_steps - 1
                    while neg_start < batch_max_para_num:
                        neg_end = min(neg_start + args.neg_chunk - 1,
                                      batch_max_para_num - 1)
                        neg_len = (neg_end - neg_start + 1)

                        input_ids_ = torch.cat(
                            (input_ids[:, :batch_max_steps - 1, :],
                             input_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        input_masks_ = torch.cat(
                            (input_masks[:, :batch_max_steps - 1, :],
                             input_masks[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        segment_ids_ = torch.cat(
                            (segment_ids[:, :batch_max_steps - 1, :],
                             segment_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        output_masks_ = torch.cat(
                            (output_masks[:, :, :batch_max_steps - 1],
                             output_masks[:, :, neg_start:neg_start + neg_len],
                             output_masks[:, :, batch_max_para_num:
                                          batch_max_para_num + 1]),
                            dim=2)
                        target_ = torch.cat(
                            (target[:, :, :batch_max_steps - 1],
                             target[:, :, neg_start:neg_start + neg_len],
                             target[:, :,
                                    batch_max_para_num:batch_max_para_num +
                                    1]),
                            dim=2)

                        if neg_start != batch_max_steps - 1:
                            output_masks_[:, :, :batch_max_steps - 1] = 0.0
                            output_masks_[:, :, -1] = 0.0

                        loss = model(input_ids_, segment_ids_, input_masks_,
                                     output_masks_, target_, batch_max_steps)

                        if n_gpu > 1:
                            loss = loss.mean(
                            )  # mean() to average on multi-gpu.
                        if args.gradient_accumulation_steps > 1:
                            loss = loss / args.gradient_accumulation_steps

                        if args.fp16:
                            with amp.scale_loss(loss,
                                                optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                        tr_loss += loss.item()
                        neg_start = neg_end + 1

                        # del input_ids_
                        # del input_masks_
                        # del segment_ids_
                        # del output_masks_
                        # del target_

                    if (step + 1) % args.gradient_accumulation_steps == 0:

                        if args.fp16:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), 1.0)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(), 1.0)

                        optimizer.step()
                        scheduler.step()
                        # optimizer.zero_grad()
                        model.zero_grad()
                        global_step += 1

                        if global_step % 50 == 0:
                            _cur_steps = global_step if args.resume is None else global_step - args.resume
                            logger.info(
                                f"Training loss: {tr_loss / _cur_steps}\t"
                                f"Learning rate: {scheduler.get_lr()[0]}\t"
                                f"Global step: {global_step}")

                        if global_step % args.save_steps == 0:
                            if args.local_rank in [-1, 0]:
                                model_to_save = model.module if hasattr(
                                    model, 'module') else model
                                output_model_file = os.path.join(
                                    args.oss_cache_dir,
                                    f"pytorch_model_{global_step}.bin")
                                torch_save_to_oss(model_to_save.state_dict(),
                                                  output_model_file)

                            _suffix = "" if args.local_rank == -1 else f"_{args.local_rank}"
                            if args.fp16:
                                amp_file = os.path.join(
                                    args.oss_cache_dir,
                                    f"amp_{global_step}{_suffix}.bin")
                                torch_save_to_oss(amp.state_dict(), amp_file)
                            optimizer_file = os.path.join(
                                args.oss_cache_dir,
                                f"optimizer_{global_step}{_suffix}.pt")
                            torch_save_to_oss(optimizer.state_dict(),
                                              optimizer_file)
                            scheduler_file = os.path.join(
                                args.oss_cache_dir,
                                f"scheduler_{global_step}{_suffix}.pt")
                            torch_save_to_oss(scheduler.state_dict(),
                                              scheduler_file)

                            logger.info(
                                f"checkpoint of step {global_step} is saved to oss."
                            )

                    # del input_ids
                    # del input_masks
                    # del segment_ids
                    # del output_masks
                    # del target
                    # del batch

                chunk_index += 1
                train_start_index = train_end_index + 1

                # Save the model at the half of the epoch
                if (chunk_index == CHUNK_NUM // 2
                        or save_retry) and args.local_rank in [-1, 0]:
                    status = save(model, args.output_dir, str(epc + 0.5))
                    save_retry = (not status)

                del train_features_
                del all_input_ids
                del all_input_masks
                del all_segment_ids
                del all_output_masks
                del all_num_paragraphs
                del all_num_steps
                del train_data
                del train_sampler
                del train_dataloader
                gc.collect()

            # Save the model at the end of the epoch
            if args.local_rank in [-1, 0]:
                save(model, args.output_dir, str(epc + 1))
                # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                # output_model_file = os.path.join(args.oss_cache_dir, "pytorch_model_" + str(epc + 1) + ".bin")
                # torch_save_to_oss(model_to_save.state_dict(), output_model_file)

            epc += 1

    if do_train:
        return

    ##############################
    # Evaluation                 #
    ##############################
    assert args.model_suffix is not None

    if graph_retriever_config.db_save_path is not None:
        import sys
        sys.path.append('../')
        from pipeline.tfidf_retriever import TfidfRetriever
        tfidf_retriever = TfidfRetriever(graph_retriever_config.db_save_path,
                                         None)
    else:
        tfidf_retriever = None

    if args.oss_cache_dir is not None:
        file_name = 'pytorch_model_' + args.model_suffix + '.bin'
        model_state_dict = torch.load(
            load_buffer_from_oss(os.path.join(args.oss_cache_dir, file_name)))
    else:
        model_state_dict = load(args.output_dir, args.model_suffix)

    model = RobertaForGraphRetriever.from_pretrained(
        args.bert_model,
        state_dict=model_state_dict,
        graph_retriever_config=graph_retriever_config)
    model.to(device)

    model.eval()

    if args.pred_file is not None:
        pred_output = []

    eval_examples = processor.get_dev_examples(graph_retriever_config)

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    TOTAL_NUM = len(eval_examples)
    eval_start_index = 0

    while eval_start_index < TOTAL_NUM:
        eval_end_index = min(
            eval_start_index + graph_retriever_config.eval_chunk - 1,
            TOTAL_NUM - 1)
        chunk_len = eval_end_index - eval_start_index + 1

        eval_features = convert_examples_to_features(
            eval_examples[eval_start_index:eval_start_index + chunk_len],
            args.max_seq_length, args.max_para_num, graph_retriever_config,
            tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_masks = torch.tensor([f.input_masks for f in eval_features],
                                       dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_output_masks = torch.tensor(
            [f.output_masks for f in eval_features], dtype=torch.float)
        all_num_paragraphs = torch.tensor(
            [f.num_paragraphs for f in eval_features], dtype=torch.long)
        all_num_steps = torch.tensor([f.num_steps for f in eval_features],
                                     dtype=torch.long)
        all_ex_indices = torch.tensor([f.ex_index for f in eval_features],
                                      dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_masks,
                                  all_segment_ids, all_output_masks,
                                  all_num_paragraphs, all_num_steps,
                                  all_ex_indices)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        for input_ids, input_masks, segment_ids, output_masks, num_paragraphs, num_steps, ex_indices in tqdm(
                eval_dataloader, desc="Evaluating"):
            batch_max_len = input_masks.sum(dim=2).max().item()
            batch_max_para_num = num_paragraphs.max().item()

            batch_max_steps = num_steps.max().item()

            input_ids = input_ids[:, :batch_max_para_num, :batch_max_len]
            input_masks = input_masks[:, :batch_max_para_num, :batch_max_len]
            segment_ids = segment_ids[:, :batch_max_para_num, :batch_max_len]
            output_masks = output_masks[:, :batch_max_para_num +
                                        2, :batch_max_para_num + 1]
            output_masks[:, 1:, -1] = 1.0  # Ignore EOE in the first step

            input_ids = input_ids.to(device)
            input_masks = input_masks.to(device)
            segment_ids = segment_ids.to(device)
            output_masks = output_masks.to(device)

            examples = [
                eval_examples[eval_start_index + ex_indices[i].item()]
                for i in range(input_ids.size(0))
            ]

            with torch.no_grad():
                pred, prob, topk_pred, topk_prob = model.beam_search(
                    input_ids,
                    segment_ids,
                    input_masks,
                    examples=examples,
                    tokenizer=tokenizer,
                    retriever=tfidf_retriever,
                    split_chunk=args.split_chunk)

            for i in range(len(pred)):
                e = examples[i]
                titles = [e.title_order[p] for p in pred[i]]

                # Output predictions to a file
                if args.pred_file is not None:
                    pred_output.append({})
                    pred_output[-1]['q_id'] = e.guid

                    pred_output[-1]['titles'] = titles
                    pred_output[-1]['probs'] = []
                    for prob_ in prob[i]:
                        entry = {'EOE': prob_[-1]}
                        for j in range(len(e.title_order)):
                            entry[e.title_order[j]] = prob_[j]
                        pred_output[-1]['probs'].append(entry)

                    topk_titles = [[e.title_order[p] for p in topk_pred[i][j]]
                                   for j in range(len(topk_pred[i]))]
                    pred_output[-1]['topk_titles'] = topk_titles

                    topk_probs = []
                    for k in range(len(topk_prob[i])):
                        topk_probs.append([])
                        for prob_ in topk_prob[i][k]:
                            entry = {'EOE': prob_[-1]}
                            for j in range(len(e.title_order)):
                                entry[e.title_order[j]] = prob_[j]
                            topk_probs[-1].append(entry)
                    pred_output[-1]['topk_probs'] = topk_probs

                    # Output the selected paragraphs
                    context = {}
                    for ts in topk_titles:
                        for t in ts:
                            context[t] = e.all_paras[t]
                    pred_output[-1]['context'] = context

        eval_start_index = eval_end_index + 1

        del eval_features
        del all_input_ids
        del all_input_masks
        del all_segment_ids
        del all_output_masks
        del all_num_paragraphs
        del all_num_steps
        del all_ex_indices
        del eval_data

    if args.pred_file is not None:
        json.dump(pred_output, open(args.pred_file, 'w'))
def train(config, model, train_iter, dev_iter, test_iter):
    start_time = time.time()
    model.train()
    param_optimizer = list(model.named_parameters())
    #no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # LayerNorm and bias parameters should not receive weight decay
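    # The active grouping below instead splits by module: BERT encoder parameters are
    # trained with config.bert_learning_rate, everything else with config.other_learning_rate.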
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer if "bert" in n],
            "lr": config.bert_learning_rate,
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in param_optimizer if "bert" not in n],
            "lr": config.other_learning_rate,
            'weight_decay': 0.01
        },
    ]

    # LayerNorm and bias parameters should not receive weight decay
    # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # optimizer_grouped_parameters = [
    #     {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
    #      'weight_decay': config.weight_decay},
    #     {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    # ]

    #optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    #optimizer = BertAdam(optimizer_grouped_parameters, lr=config.learning_rate,  warmup=0.05,t_total=len(train_iter) * config.num_epochs)

    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=config.bert_learning_rate,
        correct_bias=False
    )  # To reproduce BertAdam specific behavior set correct_bias=False

    # get_linear_schedule_with_warmup expects an integer warmup-step count, so the
    # 5% warmup proportion (cf. the commented-out BertAdam call above) is converted explicitly.
    num_training_steps = len(train_iter) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.05 * num_training_steps),
        num_training_steps=num_training_steps)

    #optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate, correct_bias=False)
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index of the last improvement in dev loss
    flag = False  # set to True when there has been no improvement for a long time
    model.train()
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for i, data in enumerate(train_iter):
            data = tuple(t.to(config.device) for t in data)
            input_ids, input_mask, segment_ids, label_ids = data
            outputs = model(data)
            model.zero_grad()
            loss = F.cross_entropy(outputs, label_ids)
            loss.backward()

            optimizer.step()
            scheduler.step()

            if total_batch % 100 == 0:
                # every 100 batches, report performance on the training and dev sets
                true = label_ids.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(
                    msg.format(total_batch, loss.item(), train_acc, dev_loss,
                               dev_acc, time_dif, improve))
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # dev loss has not dropped for more than config.require_improvement batches; stop training
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
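
# For reference, here is a minimal sketch of how the train() routine above might be
# driven. The Config fields are assumptions inferred from the attributes the function
# accesses (bert_learning_rate, other_learning_rate, num_epochs, device, save_path,
# require_improvement); they are illustrative, not part of the original snippet.
class Config:
    bert_learning_rate = 2e-5      # LR for parameters whose name contains "bert"
    other_learning_rate = 1e-3     # LR for all remaining parameters
    num_epochs = 3
    device = "cuda"                # torch device string used by .to(config.device)
    save_path = "best_model.ckpt"  # where the best dev-loss checkpoint is written
    require_improvement = 1000     # early-stopping patience, in batches

# train(Config(), model, train_iter, dev_iter, test_iter)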
class LengthDropTrainer(Trainer):
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer = None,
        best_metric: str = 'acc',
        length_drop_args: LengthDropArguments = None,
        **kwargs,
    ):
        super(LengthDropTrainer, self).__init__(**kwargs)
        self.tokenizer = tokenizer
        self.best_metric = best_metric
        if length_drop_args is None:
            length_drop_args = LengthDropArguments()
        self.length_drop_args = length_drop_args

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.
        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            retention_params = []
            wd_params = []
            no_wd_params = []
            for n, p in self.model.named_parameters():
                if "retention" in n:
                    retention_params.append(p)
                elif any(nd in n for nd in no_decay):
                    no_wd_params.append(p)
                else:
                    wd_params.append(p)
            optimizer_grouped_parameters = [
                {"params": wd_params, "weight_decay": self.args.weight_decay, "lr": self.args.learning_rate},
                {"params": no_wd_params, "weight_decay": 0.0, "lr": self.args.learning_rate}
            ]
            if len(retention_params) > 0:
                optimizer_grouped_parameters.append(
                    {"params": retention_params, "weight_decay": 0.0, "lr": self.length_drop_args.lr_soft_extract}
                )
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            if self.args.warmup_ratio is not None:
                num_warmup_steps = int(self.args.warmup_ratio * num_training_steps)
            else:
                num_warmup_steps = self.args.warmup_steps
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
            )
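        # Note on the grouping above (inferred from the code): parameters whose name
        # contains "retention" form their own group with no weight decay and the
        # separate lr_soft_extract learning rate, bias/LayerNorm.weight parameters
        # skip weight decay, and everything else uses args.weight_decay with the
        # base learning rate.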

    def div_loss(self, loss):
        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps
        return loss

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.
        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
            self.args.max_steps = t_total

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
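        # The effective (global) batch size: the per-step train batch size multiplied
        # by the gradient-accumulation factor and, in distributed mode, by the number
        # of participating processes.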
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])

                epochs_trained = self.global_step // num_update_steps_per_epoch
                steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss_sum = 0.0
        loss_sum = defaultdict(float)
        best = {self.best_metric: None}
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                model.train()
                inputs = self._prepare_inputs(inputs)

                inputs["output_attentions"] = self.length_drop_args.length_config is not None

                layer_config = sample_layer_configuration(
                    model.config.num_hidden_layers,
                    layer_dropout_prob=self.length_drop_args.layer_dropout_prob,
                    layer_dropout=0,
                )
                inputs["layer_config"] = layer_config

                inputs["length_config"] = self.length_drop_args.length_config

                outputs = model(**inputs)
                # Save past state if it exists
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index]
                task_loss = self.div_loss(outputs[0])
                if self.length_drop_args.length_adaptive:
                    loss_sum["full"] += task_loss.item()
                loss = task_loss
                if self.length_drop_args.length_adaptive:
                    loss = loss / (self.length_drop_args.num_sandwich + 2)

                tr_loss_sum += loss.item()
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # inplace distillation
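                # How the sandwich scheme below works (as implemented here): the
                # detached logits of the full-length forward pass above serve as the
                # teacher; num_sandwich + 1 reduced configurations are sampled
                # (i == 0 uses the layer/length drop bounds, i.e. the most aggressive
                # reduction), and each is trained with a KL-divergence loss against
                # the teacher logits. Every backward pass is scaled by
                # 1 / (num_sandwich + 2) so the full model and the sub-models
                # contribute equally; the sub-models' task losses are tracked only
                # for logging.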
                if self.length_drop_args.length_adaptive:
                    logits = outputs[1].detach()

                    for i in range(self.length_drop_args.num_sandwich + 1):
                        inputs["output_attentions"] = True

                        layer_config = sample_layer_configuration(
                            model.config.num_hidden_layers,
                            layer_dropout_prob=self.length_drop_args.layer_dropout_prob,
                            layer_dropout=(self.length_drop_args.layer_dropout_bound if i == 0 else None),
                            layer_dropout_bound=self.length_drop_args.layer_dropout_bound,
                        )
                        inputs["layer_config"] = layer_config

                        length_config = sample_length_configuration(
                            self.args.max_seq_length,
                            model.config.num_hidden_layers,
                            layer_config,
                            length_drop_ratio=(self.length_drop_args.length_drop_ratio_bound if i == 0 else None),
                            length_drop_ratio_bound=self.length_drop_args.length_drop_ratio_bound,
                        )
                        inputs["length_config"] = length_config

                        outputs_sub = model(**inputs)
                        task_loss_sub = self.div_loss(outputs_sub[0])
                        if i == 0:
                            loss_sum["smallest"] += task_loss_sub.item()
                            loss_sum["sub"] += 0
                        else:
                            loss_sum["sub"] += task_loss_sub.item() / self.length_drop_args.num_sandwich

                        logits_sub = outputs_sub[1]
                        loss_fct = KLDivLoss(reduction="batchmean")
                        kl_loss = loss_fct(F.log_softmax(logits, -1), F.softmax(logits_sub, -1))
                        loss = self.div_loss(kl_loss)
                        loss_sum["kl"] += loss.item() / (self.length_drop_args.num_sandwich + 1)
                        loss = loss / (self.length_drop_args.num_sandwich + 2)

                        tr_loss_sum += loss.item()
                        if self.args.fp16 and _use_native_amp:
                            self.scaler.scale(loss).backward()
                        elif self.args.fp16 and _use_apex:
                            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    (step + 1) == len(epoch_iterator) <= self.args.gradient_accumulation_steps
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
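                    # With native AMP, gradients are unscaled before clipping so that
                    # max_grad_norm is applied to the true gradient magnitudes; with
                    # apex, clipping is applied to the fp32 master parameters instead.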

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        # backward compatibility for pytorch schedulers
                        lr = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        loss = tr_loss_sum / self.args.logging_steps
                        tr_loss_sum = 0.0
                        logs = {"lr": lr, "loss": loss}
                        log_str = f"[{self.global_step:5d}] lr {lr:g} | loss {loss:2.3f}"

                        for key, value in loss_sum.items():
                            value /= self.args.logging_steps
                            loss_sum[key] = 0.0
                            logs[f"{key}_loss"] = value
                            log_str += f" | {key}_loss {value:2.3f}"

                        self.log(logs, "train")
                        logger.info(log_str)

                    '''
                    if (
                        self.args.evaluation_strategy == EvaluationStrategy.STEPS
                        and self.global_step % self.args.eval_steps == 0
                    ):
                        results = self.evaluate()
                        self._report_to_hp_search(trial, epoch, results)
                    '''

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert (
                                model.module is self.model
                            ), f"Module {model.module} should be a reference to self.model"
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"

                        if self.args.evaluate_during_training:
                            results = self.evaluate()
                            results = {k[5:]: v for k, v in results.items() if k.startswith("eval_")}
                            self.log(results, "dev")
                            msg = " | ".join([f"{k} {v:.3f}" for k, v in results.items()])
                            logger.info(f"  [{self.global_step:5d}] {msg}")

                        # Save model checkpoint
                        if self.args.save_only_best:
                            output_dirs = []
                        else:
                            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
                            if self.hp_search_backend is not None and trial is not None:
                                run_id = (
                                    trial.number
                                    if self.hp_search_backend == HPSearchBackend.OPTUNA
                                    else tune.get_trial_id()
                                )
                                checkpoint_folder += f"-run-{run_id}"
                            output_dirs = [os.path.join(self.args.output_dir, checkpoint_folder)]
                            
                        if self.args.evaluate_during_training:
                            if best[self.best_metric] is None or results[self.best_metric] > best[self.best_metric]:
                                logger.info("Congratulations, best model so far!")
                                output_dirs.append(os.path.join(self.args.output_dir, "checkpoint-best"))
                                best = results

                        for output_dir in output_dirs:
                            self.save_model(output_dir)

                            if self.is_world_master() and self.tokenizer is not None:
                                self.tokenizer.save_pretrained(output_dir)

                            if self.is_world_process_zero():
                                self._rotate_checkpoints(use_mtime=True)

                            '''
                            if is_torch_tpu_available():
                                xm.rendezvous("saving_optimizer_states")
                                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            elif self.is_world_process_zero():
                                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            '''

                epoch_pbar.update(1)
                if 0 < self.args.max_steps <= self.global_step:
                    break
            epoch_pbar.close()
            train_pbar.update(1)

            '''
            if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
                results = self.evaluate()
                self._report_to_hp_search(trial, epoch, results)
            '''

            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if 0 < self.args.max_steps <= self.global_step:
                break

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return self.global_step, best

    def log(self, logs, mode="train"):
        self._setup_loggers()
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_process_zero():
                wandb.log(logs, step=self.global_step)
        if is_comet_available():
            if self.is_world_process_zero():
                experiment = comet_ml.config.get_global_experiment()
                if experiment is not None:
                    experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
        output = {**logs, **{"step": self.global_step}}
        if self.is_world_process_zero():
            self.log_history.append(output)

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        '''
        assert not getattr(
            self.model.config, "output_attentions", False
        ), "The prediction loop does not work with `output_attentions=True`."
        assert not getattr(
            self.model.config, "output_hidden_states", False
        ), "The prediction loop does not work with `output_hidden_states=True`."
        '''

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        '''
        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        '''
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            if loss is not None:
                eval_losses.extend([loss] * batch_size)
            if logits is not None:
                preds = logits if preds is None else nested_concat(preds, logits, dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = nested_xla_mesh_reduce(preds, "eval_preds")
            if label_ids is not None:
                label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
            if eval_losses is not None:
                eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = nested_numpify(preds)
        if label_ids is not None:
            label_ids = nested_numpify(label_ids)

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            if self.args.local_rank != -1:
                metrics["eval_loss"] = (
                    distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
                    .mean()
                    .item()
                )
            else:
                metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
        """
        has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
        inputs = self._prepare_inputs(inputs)

        output_attentions = inputs.get('output_attentions', None)
        output_hidden_states = inputs.get('output_hidden_states', None)

        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states

        num_additional_outputs = int(output_attentions == True) + int(output_hidden_states == True)

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                # The .mean() is to reduce in case of distributed training
                loss = outputs[0].mean().item()
                logits = outputs[1:(len(outputs) - num_additional_outputs)]
            else:
                loss = None
                # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
                logits = outputs[:(len(outputs) - num_additional_outputs)]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = tuple(logit.detach() for logit in logits)
        if len(logits) == 1:
            logits = logits[0]

        if has_labels:
            labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        return (loss, logits, labels)

    def init_evolution(self, lower_constraint=0, upper_constraint=None):
        size = (1, self.args.max_seq_length)
        self.dummy_inputs = (
            torch.ones(size, dtype=torch.long).to(self.args.device),
            torch.ones(size, dtype=torch.long).to(self.args.device),
            torch.zeros(size, dtype=torch.long).to(self.args.device),
        )
        if self.model.config.model_type == "distilbert":
            self.dummy_inputs = self.dummy_inputs[:2]


        self.lower_constraint = lower_constraint
        self.upper_constraint = upper_constraint

        self.store = {}  # gene: (macs, score, method, parent(s))
        self.population = []
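        # Bookkeeping for the evolutionary search (inferred from the methods below):
        # a "gene" is a tuple with one sequence length per hidden layer, self.store
        # maps each gene to (macs, score, method, parents), and self.population keeps
        # the genes whose MACs fall within [lower_constraint, upper_constraint].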

    def load_store(self, store_file):
        if not os.path.isfile(store_file):
            return
        with open(store_file, 'r') as f:
            for row in csv.reader(f, delimiter='\t'):
                row = tuple(eval(x) for x in row[:3])
                self.store[row[0]] = row[1:3] + (0, None)

    def save_store(self, store_file):
        store_keys = sorted(self.store.keys(), key=lambda x: self.store[x][0])
        with open(store_file, 'w') as f:
            writer = csv.writer(f, delimiter='\t')
            for gene in store_keys:
                writer.writerow([str(gene)] + [str(x) for x in self.store[gene]])

    def save_population(self, population_file, population):
        with open(population_file, 'w') as f:
            writer = csv.writer(f, delimiter='\t')
            for gene in population:
                writer.writerow([str(gene)] + [str(x) for x in self.store[gene]])

    def ccw(self, gene0, gene1, gene2):
        x0, y0 = self.store[gene0][:2]
        x1, y1 = self.store[gene1][:2]
        x2, y2 = self.store[gene2][:2]
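        # Equivalent to the 2D cross product (x1 - x0) * (y2 - y0) - (x2 - x0) * (y1 - y0),
        # i.e. twice the signed area of the triangle formed by the three (MACs, score)
        # points; convex_hull() uses its sign to decide whether to drop the middle point.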
        return (x0 * y1 + x1 * y2 + x2 * y0) - (x0 * y2 + x1 * y0 + x2 * y1)

    def convex_hull(self):
        hull = self.population[:2]
        for gene in self.population[2:]:
            if self.store[hull[-1]][1] >= self.store[gene][1]:
                continue
            while len(hull) >= 2 and self.ccw(hull[-2], hull[-1], gene) >= 0:
                del hull[-1]
            hull.append(gene)
        return hull

    def pareto_frontier(self):
        self.population = sorted(self.population, key=lambda x: self.store[x][:2])

        frontier = [self.population[0]]
        for gene in self.population[1:-1]:
            if self.store[gene][1] > self.store[frontier[-1]][1]:
                if self.store[gene][0] == self.store[frontier[-1]][0]:
                    del frontier[-1]
                frontier.append(gene)
        frontier.append(self.population[-1])
        self.population = frontier

        area = 0
        for gene0, gene1 in zip(self.population[:-1], self.population[1:]):
            x0, y0 = self.store[gene0][:2]
            x1, y1 = self.store[gene1][:2]
            area += (x1 - x0) * y0
        area /= (self.upper_constraint - self.lower_constraint)
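        # "area" approximates the area under the score-vs-MACs step curve traced by
        # the frontier, normalized by the width of the MACs constraint window, so a
        # larger value suggests a better efficiency/accuracy trade-off curve.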
        return self.population, area

    def add_gene(self, gene, macs=None, score=None, method=0, parents=None):
        if gene not in self.store:
            self.model.eval()
            if self.model.config.model_type == "distilbert":
                bert = self.model.distilbert
            else:
                assert hasattr(self.model, "bert")
                bert = self.model.bert
            bert.set_length_config(gene)
            macs = macs or torchprofile.profile_macs(self.model, args=self.dummy_inputs)
            # logger.info(gene, macs)
            if macs < self.lower_constraint:
                return False
            score = score or self.evaluate()["eval_" + self.best_metric]
            self.store[gene] = (macs, score, method, parents)
            logger.info(store2str(gene, macs, score, method, parents))

        macs = self.store[gene][0]
        if macs >= self.lower_constraint \
                and (self.upper_constraint is None or macs <= self.upper_constraint) \
                and gene not in self.population:
            self.population.append(gene)
            return True
        return False

    def mutate(self, mutation_prob):
        gene = random.choice(self.population)
        mutated_gene = ()
        for i in range(self.model.config.num_hidden_layers):
            if np.random.uniform() < mutation_prob:
                prev = (self.args.max_seq_length if i == 0 else mutated_gene[i - 1])
                next = (2 if i == self.model.config.num_hidden_layers - 1 else gene[i + 1])
                mutated_gene += (random.randrange(next, prev + 1),)
            else:
                mutated_gene += (gene[i],)
        return self.add_gene(mutated_gene, method=1, parents=(gene,))

    def crossover(self):
        gene0, gene1 = random.sample(self.population, 2)
        crossovered_gene = tuple((g0 + g1 + 1) // 2 for g0, g1 in zip(gene0, gene1))
        return self.add_gene(crossovered_gene, method=2, parents=(gene0, gene1))
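
For orientation, the comment block below sketches one way the evolutionary-search methods above could be driven; the iteration count, mutation probability, and the initial full-length gene are illustrative assumptions, not part of the original class.

# Hypothetical search loop (all numbers are illustrative):
# trainer.init_evolution(lower_constraint=0, upper_constraint=None)
# full_gene = (trainer.args.max_seq_length,) * trainer.model.config.num_hidden_layers
# trainer.add_gene(full_gene, method=0)          # seed with the unreduced model
# for _ in range(30):
#     trainer.mutate(mutation_prob=0.5)          # shrink random layers of a random gene
#     trainer.crossover()                        # element-wise average of two genes
#     population, area = trainer.pareto_frontier()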
Beispiel #16
0
def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    ###############
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--pretrain_model",
                        default='bert-base-uncased',
                        type=str,
                        required=True,
                        help="Pre-trained model")
    parser.add_argument("--num_labels_task",
                        default=None,
                        type=int,
                        required=True,
                        help="num_labels_task")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task",
                        default=None,
                        type=int,
                        required=True,
                        help="Choose Task")
    ###############

    args = parser.parse_args()

    processors = Processor_1

    num_labels = args.num_labels_task

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {}, n_gpu: {}, distributed training: {}, 16-bits training: {}"
        .format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)

    train_examples = None
    num_train_steps = None
    aspect_list = None
    sentiment_list = None
    processor = processors()
    num_labels = num_labels
    train_examples, aspect_list, sentiment_list = processor.get_train_examples(
        args.data_dir)

    if args.task == 1:
        num_labels = len(aspect_list)
    elif args.task == 2:
        num_labels = len(sentiment_list)
    else:
        print("What's task?")
        exit()

    num_train_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    #model = RobertaForSequenceClassification.from_pretrained(args.pretrain_model, num_labels=args.num_labels_task, output_hidden_states=False, output_attentions=False, return_dict=True)
    model = RobertaForMaskedLMDomainTask.from_pretrained(
        args.pretrain_model,
        num_labels=args.num_labels_task,
        output_hidden_states=False,
        output_attentions=False,
        return_dict=True)

    # Prepare optimizer
    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()

    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    #no_decay = ['bias', 'LayerNorm.weight']
    no_grad = [
        'bert.encoder.layer.11.output.dense_ent',
        'bert.encoder.layer.11.output.LayerNorm_ent'
    ]
    param_optimizer = [(n, p) for n, p in param_optimizer
                       if not any(nd in n for nd in no_grad)]
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=int(t_total *
                                                                     0.1),
                                                num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    global_step = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      aspect_list,
                                                      sentiment_list,
                                                      args.max_seq_length,
                                                      tokenizer, args.task)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_attention_mask = torch.tensor(
            [f.attention_mask for f in train_features], dtype=torch.long)
        if args.task == 1:
            print("Excuting the task 1")
        elif args.task == 2:
            all_segment_ids = torch.tensor(
                [f.segment_ids for f in train_features], dtype=torch.long)
        else:
            print("Wrong here2")

        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        if args.task == 1:
            train_data = TensorDataset(all_input_ids, all_attention_mask,
                                       all_label_ids)
        elif args.task == 2:
            train_data = TensorDataset(all_input_ids, all_attention_mask,
                                       all_segment_ids, all_label_ids)
        else:
            print("Wrong here1")
        '''
        print("========")
        print(train_data)
        print(type(train_data))
        exit()
        '''

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        output_loss_file = os.path.join(args.output_dir, "loss")
        loss_fout = open(output_loss_file, 'w')
        model.train()

        ########## Pre-process ##########
        ###############################

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                #batch = tuple(t.to(device) if i != 3 else t for i, t in enumerate(batch))
                batch = tuple(t.to(device) for t in batch)

                if args.task == 1:
                    input_ids, attention_mask, label_ids = batch
                elif args.task == 2:
                    input_ids, attention_mask, segment_ids, label_ids = batch
                else:
                    print("Wrong here3")

                if args.task == 1:
                    #loss, logits, hidden_states, attentions
                    #output = model(input_ids=input_ids, token_type_ids=None, attention_mask=attention_mask, labels=label_ids)
                    #loss = output.loss
                    loss, logit = model(input_ids_org=input_ids,
                                        token_type_ids=None,
                                        attention_mask=attention_mask,
                                        sentence_label=label_ids,
                                        func="task_class")
                elif args.task == 2:
                    #loss, logits, hidden_states, attentions
                    #output = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=attention_mask, labels=label_ids)
                    #output = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=attention_mask, labels=label_ids)
                    #output = model(input_ids=input_ids, token_type_ids=None, attention_mask=attention_mask, labels=label_ids)
                    #loss = output.loss
                    loss, logit = model(input_ids_org=input_ids,
                                        token_type_ids=None,
                                        attention_mask=attention_mask,
                                        sentence_label=label_ids,
                                        func="task_class")
                else:
                    print("Wrong!!")

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    ###
                    #optimizer.backward(loss)
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    ###
                else:
                    loss.backward()

                loss_fout.write("{}\n".format(loss.item()))
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    ###
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1
                    ###
            if epoch < -1:
                continue
            else:
                model_to_save = model.module if hasattr(model,
                                                        'module') else model
                #output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
                output_model_file = os.path.join(
                    args.output_dir, "pytorch_model.bin_{}".format(epoch))
                torch.save(model_to_save.state_dict(), output_model_file)

        # Save a trained model
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        torch.save(model_to_save.state_dict(), output_model_file)
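
As a small follow-up (not part of the original example), the checkpoint written above could later be reloaded for inference roughly as follows; the model class and paths simply mirror the ones used in this script.

# Hypothetical reload of the checkpoint saved above:
# model = RobertaForMaskedLMDomainTask.from_pretrained(
#     args.pretrain_model, num_labels=args.num_labels_task,
#     output_hidden_states=False, output_attentions=False, return_dict=True)
# state_dict = torch.load(os.path.join(args.output_dir, "pytorch_model.bin"),
#                         map_location="cpu")
# model.load_state_dict(state_dict)
# model.eval()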
Beispiel #17
0
def main():
    parser = argparse.ArgumentParser()

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    parser.add_argument('--kshot',
                        type=int,
                        default=5,
                        help="random seed for initialization")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--target_train_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--update_BERT_top_layers',
                        type=int,
                        default=1,
                        help="number of top transformer layers to fine-tune "
                        "(lower layers are frozen)")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    target_kshot_entail_examples, target_kshot_nonentail_examples, target_dev_examples, target_test_examples = load_FewRel_GFS_Entail(
        args.kshot)

    system_seed = 42
    random.seed(system_seed)
    np.random.seed(system_seed)
    torch.manual_seed(system_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(system_seed)

    source_kshot_size = 10  # number of MNLI support examples drawn per class
    source_kshot_entail, source_kshot_neural, source_kshot_contra, source_remaining_examples = get_MNLI_train(
        '/export/home/Dataset/glue_data/MNLI/train.tsv', source_kshot_size)
    source_examples = source_kshot_entail + source_kshot_neural + source_kshot_contra + source_remaining_examples
    target_label_list = ["entailment", "non_entailment"]
    source_label_list = ["entailment", "neutral", "contradiction"]
    # entity_label_list = ["A-coref", "B-coref"]
    source_num_labels = len(source_label_list)
    target_num_labels = len(target_label_list)
    print('training size:', len(source_examples), 'dev size:',
          len(target_dev_examples), 'test size:', len(target_test_examples))

    roberta_model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    roberta_model.load_state_dict(torch.load(
        '/export/home/Dataset/BERT_pretrained_mine/MNLI_pretrained/_acc_0.9040886899918633.pt'
    ),
                                  strict=False)
    '''
    freeze everything below the top `update_BERT_top_layers` transformer layers:
    the embedding block contributes 5 parameter tensors and each of the 24
    RoBERTa-large layers contributes 16, so parameters are frozen by position
    '''
    param_size = 0
    update_top_layer_size = args.update_BERT_top_layers
    for name, param in roberta_model.named_parameters():
        if param_size < (5 + 16 * (24 - update_top_layer_size)):
            param.requires_grad = False
        param_size += 1
    roberta_model.to(device)

    protonet = PrototypeNet(bert_hidden_dim)
    protonet.to(device)

    param_optimizer = list(protonet.named_parameters()) + list(
        roberta_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    retrieve_batch_size = 5

    source_kshot_entail_dataloader = examples_to_features(
        source_kshot_entail,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_neural_dataloader = examples_to_features(
        source_kshot_neural,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_kshot_contra_dataloader = examples_to_features(
        source_kshot_contra,
        source_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    source_remain_ex_dataloader = examples_to_features(
        source_remaining_examples,
        source_label_list,
        args,
        tokenizer,
        args.train_batch_size,
        "classification",
        dataloader_mode='random')

    target_kshot_entail_dataloader = examples_to_features(
        target_kshot_entail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_kshot_nonentail_dataloader = examples_to_features(
        target_kshot_nonentail_examples,
        target_label_list,
        args,
        tokenizer,
        retrieve_batch_size,
        "classification",
        dataloader_mode='sequential')
    target_dev_dataloader = examples_to_features(target_dev_examples,
                                                 target_label_list,
                                                 args,
                                                 tokenizer,
                                                 args.eval_batch_size,
                                                 "classification",
                                                 dataloader_mode='sequential')
    target_test_dataloader = examples_to_features(target_test_examples,
                                                  target_label_list,
                                                  args,
                                                  tokenizer,
                                                  args.eval_batch_size,
                                                  "classification",
                                                  dataloader_mode='sequential')
    '''starting to train'''
    iter_co = 0
    tr_loss = 0
    source_loss = 0
    target_loss = 0
    final_test_performance = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):

        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm(source_remain_ex_dataloader, desc="Iteration")):
            protonet.train()
            batch = tuple(t.to(device) for t in batch)
            _, input_ids, input_mask, segment_ids, source_label_ids_batch = batch

            roberta_model.train()
            source_last_hidden_batch, _ = roberta_model(input_ids, input_mask)
            '''
            retrieve rep for support examples in MNLI
            '''
            kshot_entail_reps = torch.zeros(1, bert_hidden_dim).to(device)
            entail_batch_i = 0
            for entail_batch in source_kshot_entail_dataloader:
                roberta_model.train()
                last_hidden_entail, _ = roberta_model(
                    entail_batch[1].to(device), entail_batch[2].to(device))
                kshot_entail_reps += torch.mean(last_hidden_entail,
                                                dim=0,
                                                keepdim=True)
                entail_batch_i += 1
            kshot_entail_rep = kshot_entail_reps / entail_batch_i
            kshot_neural_reps = torch.zeros(1, bert_hidden_dim).to(device)
            neural_batch_i = 0
            for neural_batch in source_kshot_neural_dataloader:
                roberta_model.train()
                last_hidden_neural, _ = roberta_model(
                    neural_batch[1].to(device), neural_batch[2].to(device))
                kshot_neural_reps += torch.mean(last_hidden_neural,
                                                dim=0,
                                                keepdim=True)
                neural_batch_i += 1
            kshot_neural_rep = kshot_neural_reps / neural_batch_i
            kshot_contra_reps = torch.zeros(1, bert_hidden_dim).to(device)
            contra_batch_i = 0
            for contra_batch in source_kshot_contra_dataloader:
                roberta_model.train()
                last_hidden_contra, _ = roberta_model(
                    contra_batch[1].to(device), contra_batch[2].to(device))
                kshot_contra_reps += torch.mean(last_hidden_contra,
                                                dim=0,
                                                keepdim=True)
                contra_batch_i += 1
            kshot_contra_rep = kshot_contra_reps / contra_batch_i

            source_class_prototype_reps = torch.cat(
                [kshot_entail_rep, kshot_neural_rep, kshot_contra_rep],
                dim=0)  #(3, hidden)
            '''first get representations for support examples in target'''
            target_kshot_entail_dataloader_subset = examples_to_features(
                random.sample(target_kshot_entail_examples, 10),
                target_label_list,
                args,
                tokenizer,
                retrieve_batch_size,
                "classification",
                dataloader_mode='sequential')
            target_kshot_nonentail_dataloader_subset = examples_to_features(
                random.sample(target_kshot_nonentail_examples, 10),
                target_label_list,
                args,
                tokenizer,
                retrieve_batch_size,
                "classification",
                dataloader_mode='sequential')
            kshot_entail_reps = []
            for entail_batch in target_kshot_entail_dataloader_subset:
                roberta_model.train()
                last_hidden_entail, _ = roberta_model(
                    entail_batch[1].to(device), entail_batch[2].to(device))
                kshot_entail_reps.append(last_hidden_entail)
            all_kshot_entail_reps = torch.cat(kshot_entail_reps, dim=0)
            kshot_entail_rep = torch.mean(all_kshot_entail_reps,
                                          dim=0,
                                          keepdim=True)
            kshot_nonentail_reps = []
            for nonentail_batch in target_kshot_nonentail_dataloader_subset:
                roberta_model.train()
                last_hidden_nonentail, _ = roberta_model(
                    nonentail_batch[1].to(device),
                    nonentail_batch[2].to(device))
                kshot_nonentail_reps.append(last_hidden_nonentail)
            all_kshot_neural_reps = torch.cat(kshot_nonentail_reps, dim=0)
            kshot_nonentail_rep = torch.mean(all_kshot_neural_reps,
                                             dim=0,
                                             keepdim=True)
            # the non-entailment prototype is repeated so the target side also
            # forms a 3-class prototype set matching the source layout
            target_class_prototype_reps = torch.cat(
                [kshot_entail_rep, kshot_nonentail_rep, kshot_nonentail_rep],
                dim=0)  #(3, hidden)

            class_prototype_reps = torch.cat(
                [source_class_prototype_reps, target_class_prototype_reps],
                dim=0)  #(6, hidden)
            '''forward to model'''

            target_batch_size = args.target_train_batch_size  #10*3
            # print('target_batch_size:', target_batch_size)
            target_batch_size_entail = target_batch_size  #random.randrange(5)+1
            target_batch_size_neural = target_batch_size  #random.randrange(5)+1

            selected_target_entail_rep = all_kshot_entail_reps[torch.randperm(
                all_kshot_entail_reps.shape[0])[:target_batch_size_entail]]
            # print('selected_target_entail_rep:', selected_target_entail_rep.shape)
            selected_target_neural_rep = all_kshot_neural_reps[torch.randperm(
                all_kshot_neural_reps.shape[0])[:target_batch_size_neural]]
            # print('selected_target_neural_rep:', selected_target_neural_rep.shape)
            target_last_hidden_batch = torch.cat(
                [selected_target_entail_rep, selected_target_neural_rep])

            last_hidden_batch = torch.cat(
                [source_last_hidden_batch, target_last_hidden_batch],
                dim=0)  #(train_batch_size+10*2)
            # print('last_hidden_batch shape:', last_hidden_batch.shape)
            batch_logits = protonet(class_prototype_reps, last_hidden_batch)
            # exit(0)
            '''source side loss'''
            # loss_fct = CrossEntropyLoss(reduction='none')
            loss_fct = CrossEntropyLoss()
            source_loss_list = loss_fct(
                batch_logits[:source_last_hidden_batch.shape[0]].view(
                    -1, source_num_labels), source_label_ids_batch.view(-1))
            '''target side loss'''
            target_label_ids_batch = torch.tensor(
                [0] * selected_target_entail_rep.shape[0] +
                [1] * selected_target_neural_rep.shape[0],
                dtype=torch.long)
            target_batch_logits = batch_logits[-target_last_hidden_batch.
                                               shape[0]:]
            target_loss_list = loss_by_logits_and_2way_labels(
                target_batch_logits, target_label_ids_batch.view(-1), device)
            # target_loss_list = loss_fct(target_batch_logits.view(-1, source_num_labels), target_label_ids_batch.to(device).view(-1))
            loss = source_loss_list + target_loss_list  #torch.mean(torch.cat([source_loss_list, target_loss_list]))
            # accumulate scalars, not tensors, so the autograd graph is not kept alive
            source_loss += source_loss_list.item()
            target_loss += target_loss_list.item()
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            global_step += 1
            iter_co += 1
            # print('iter_co:', iter_co, 'mean loss:', tr_loss/iter_co)
            if iter_co % 20 == 0:
                # if iter_co % len(source_remain_ex_dataloader)==0:
                '''
                periodically evaluate on the dev set (every 20 iterations)
                '''
                '''
                retrieve rep for support examples in MNLI
                '''
                kshot_entail_reps = torch.zeros(1, bert_hidden_dim).to(device)
                entail_batch_i = 0
                for entail_batch in source_kshot_entail_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_entail, _ = roberta_model(
                            entail_batch[1].to(device),
                            entail_batch[2].to(device))
                    kshot_entail_reps += torch.mean(last_hidden_entail,
                                                    dim=0,
                                                    keepdim=True)
                    entail_batch_i += 1
                kshot_entail_rep = kshot_entail_reps / entail_batch_i
                kshot_neural_reps = torch.zeros(1, bert_hidden_dim).to(device)
                neural_batch_i = 0
                for neural_batch in source_kshot_neural_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_neural, _ = roberta_model(
                            neural_batch[1].to(device),
                            neural_batch[2].to(device))
                    kshot_neural_reps += torch.mean(last_hidden_neural,
                                                    dim=0,
                                                    keepdim=True)
                    neural_batch_i += 1
                kshot_neural_rep = kshot_neural_reps / neural_batch_i
                kshot_contra_reps = torch.zeros(1, bert_hidden_dim).to(device)
                contra_batch_i = 0
                for contra_batch in source_kshot_contra_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_contra, _ = roberta_model(
                            contra_batch[1].to(device),
                            contra_batch[2].to(device))
                    kshot_contra_reps += torch.mean(last_hidden_contra,
                                                    dim=0,
                                                    keepdim=True)
                    contra_batch_i += 1
                kshot_contra_rep = kshot_contra_reps / contra_batch_i

                source_class_prototype_reps = torch.cat(
                    [kshot_entail_rep, kshot_neural_rep, kshot_contra_rep],
                    dim=0)  #(3, hidden)
                '''first get representations for support examples in target'''
                # target_kshot_entail_dataloader_subset = examples_to_features(random.sample(target_kshot_entail_examples, args.kshot), target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
                # target_kshot_nonentail_dataloader_subset = examples_to_features(random.sample(target_kshot_nonentail_examples, args.kshot), target_label_list, args, tokenizer, retrieve_batch_size, "classification", dataloader_mode='sequential')
                kshot_entail_reps = torch.zeros(1, bert_hidden_dim).to(device)
                entail_batch_i = 0
                for entail_batch in target_kshot_entail_dataloader_subset:  #target_kshot_entail_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_entail, _ = roberta_model(
                            entail_batch[1].to(device),
                            entail_batch[2].to(device))
                    kshot_entail_reps += torch.mean(last_hidden_entail,
                                                    dim=0,
                                                    keepdim=True)
                    entail_batch_i += 1
                kshot_entail_rep = kshot_entail_reps / entail_batch_i
                kshot_nonentail_reps = torch.zeros(1,
                                                   bert_hidden_dim).to(device)
                nonentail_batch_i = 0
                for nonentail_batch in target_kshot_nonentail_dataloader_subset:  #target_kshot_nonentail_dataloader:
                    roberta_model.eval()
                    with torch.no_grad():
                        last_hidden_nonentail, _ = roberta_model(
                            nonentail_batch[1].to(device),
                            nonentail_batch[2].to(device))
                    kshot_nonentail_reps += torch.mean(last_hidden_nonentail,
                                                       dim=0,
                                                       keepdim=True)
                    nonentail_batch_i += 1
                kshot_nonentail_rep = kshot_nonentail_reps / nonentail_batch_i
                target_class_prototype_reps = torch.cat([
                    kshot_entail_rep, kshot_nonentail_rep, kshot_nonentail_rep
                ],
                                                        dim=0)  #(3, hidden)

                class_prototype_reps = torch.cat(
                    [source_class_prototype_reps, target_class_prototype_reps],
                    dim=0)  #(6, hidden)

                protonet.eval()

                # dev_acc = evaluation(protonet, target_dev_dataloader,  device, flag='Dev')
                # print('class_prototype_reps:', class_prototype_reps)
                dev_acc = evaluation(protonet,
                                     roberta_model,
                                     class_prototype_reps,
                                     target_dev_dataloader,
                                     device,
                                     flag='Dev')
                if dev_acc > max_dev_acc:
                    max_dev_acc = dev_acc
                    print('\n\t dev acc:', dev_acc, ' max_dev_acc:',
                          max_dev_acc, '\n')
                    if dev_acc > 0.73:  #10:0.73; 5:0.66
                        test_acc = evaluation(protonet,
                                              roberta_model,
                                              class_prototype_reps,
                                              target_test_dataloader,
                                              device,
                                              flag='Test')
                        if test_acc > max_test_acc:
                            max_test_acc = test_acc

                        final_test_performance = test_acc
                        print('\n\t test acc:', test_acc, ' max_test_acc:',
                              max_test_acc, '\n')
                else:
                    print('\n\t dev acc:', dev_acc, ' max_dev_acc:',
                          max_dev_acc, '\n')

            if iter_co == 2000:
                break
    print('final_test_performance:', final_test_performance)
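PrototypeNet itself is not shown in this snippet; assuming it scores each query representation against the class prototypes (the averaged support representations built above), the core prototypical-network step can be sketched as follows (prototype_logits and its arguments are illustrative names, not part of the original code):

import torch

def prototype_logits(support_reps, support_labels, query_reps, num_classes):
    # support_reps: (num_support, hidden), support_labels: (num_support,)
    # query_reps: (num_query, hidden); each class needs at least one support example
    prototypes = torch.stack([
        support_reps[support_labels == c].mean(dim=0)
        for c in range(num_classes)
    ])  # (num_classes, hidden)
    # negative Euclidean distance to each prototype serves as the logit
    return -torch.cdist(query_reps, prototypes)  # (num_query, num_classes)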
Beispiel #18
0
def main(args):
    local_config = json.load(open(args.local_config_path))
    local_config['loss'] = args.loss
    local_config['data_dir'] = args.data_dir
    local_config['train_batch_size'] = args.train_batch_size
    local_config[
        'gradient_accumulation_steps'] = args.gradient_accumulation_steps
    local_config['lr_scheduler'] = args.lr_scheduler
    local_config['model_name'] = args.model_name
    local_config['pool_type'] = args.pool_type
    local_config['seed'] = args.seed
    local_config['do_train'] = args.do_train
    local_config['do_validation'] = args.do_validation
    local_config['do_eval'] = args.do_eval
    local_config['use_cuda'] = args.use_cuda.lower() == 'true'
    local_config['num_train_epochs'] = args.num_train_epochs
    local_config['eval_batch_size'] = args.eval_batch_size
    local_config['max_seq_len'] = args.max_seq_len
    local_config['syns'] = ["Target", "Synonym"]
    local_config['target_embeddings'] = args.target_embeddings
    local_config['symmetric'] = args.symmetric.lower() == 'true'
    local_config['mask_syns'] = args.mask_syns
    local_config['train_scd'] = args.train_scd
    local_config['ckpt_path'] = args.ckpt_path
    local_config['head_batchnorm'] = args.head_batchnorm
    local_config['head_hidden_size'] = args.head_hidden_size
    local_config['linear_head'] = args.linear_head.lower() == 'true'
    local_config['emb_size_for_cosine'] = args.emb_size_for_cosine
    local_config['add_fc_layer'] = args.add_fc_layer

    if local_config['do_train'] and os.path.exists(args.output_dir):
        from glob import glob
        model_weights = glob(os.path.join(args.output_dir, '*.bin'))
        if model_weights:
            print(f'{model_weights}: already computed: skipping ...')
            return
        else:
            print(
                f'already existing {args.output_dir}. but without model weights ...'
            )
            return

    device = torch.device("cuda" if local_config['use_cuda'] else "cpu")
    n_gpu = torch.cuda.device_count()

    if local_config['gradient_accumulation_steps'] < 1:
        raise ValueError(
            "gradient_accumulation_steps parameter should be >= 1")

    local_config['train_batch_size'] = \
        local_config['train_batch_size'] // local_config['gradient_accumulation_steps']

    if local_config['do_train']:
        random.seed(local_config['seed'])
        np.random.seed(local_config['seed'])
        torch.manual_seed(local_config['seed'])

    if n_gpu > 0:
        torch.cuda.manual_seed_all(local_config['seed'])

    if not local_config['do_train'] and not local_config['do_eval']:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if local_config['do_train'] and not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
        os.makedirs(os.path.join(args.output_dir, 'nen-nen-weights'))
    elif local_config['do_train'] or local_config['do_validation']:
        raise ValueError(args.output_dir, 'output_dir already exists')

    suffix = datetime.now().isoformat().replace('-', '_').replace(
        ':', '_').split('.')[0].replace('T', '-')
    if local_config['do_train']:
        train_writer = SummaryWriter(log_dir=os.path.join(
            args.output_dir, f'tensorboard-{suffix}', 'train'))
        dev_writer = SummaryWriter(log_dir=os.path.join(
            args.output_dir, f'tensorboard-{suffix}', 'dev'))

        logger.addHandler(
            logging.FileHandler(
                os.path.join(args.output_dir, f"train_{suffix}.log"), 'w'))
        eval_logger.addHandler(
            logging.FileHandler(
                os.path.join(args.output_dir, f"scores_{suffix}.log"), 'w'))
    else:
        logger.addHandler(
            logging.FileHandler(
                os.path.join(args.ckpt_path, f"eval_{suffix}.log"), 'w'))

    logger.info(json.dumps(vars(args), indent=4))
    logger.info("device: {}, n_gpu: {}".format(device, n_gpu))
    if args.do_train:
        # persist the resolved configs next to the checkpoints
        with open(os.path.join(args.output_dir, 'local_config.json'), 'w') as outp:
            json.dump(local_config, outp, indent=4)
        with open(os.path.join(args.output_dir, 'args.json'), 'w') as outp:
            json.dump(vars(args), outp, indent=4)

    syns = sorted(local_config['syns'])
    id2classifier = {i: classifier for i, classifier in enumerate(syns)}

    model_name = local_config['model_name']
    data_processor = DataProcessor()

    train_dir = os.path.join(local_config['data_dir'], 'train/')
    dev_dir = os.path.join(local_config['data_dir'], 'dev')

    if local_config['do_train']:

        config = configs[local_config['model_name']]
        config = config.from_pretrained(local_config['model_name'],
                                        hidden_dropout_prob=args.dropout)
        if args.ckpt_path != '':
            model_path = args.ckpt_path
        else:
            model_path = local_config['model_name']
        model = models[model_name].from_pretrained(
            model_path,
            cache_dir=str(PYTORCH_PRETRAINED_BERT_CACHE),
            local_config=local_config,
            data_processor=data_processor,
            config=config)

        param_optimizer = list(model.named_parameters())

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                param for name, param in param_optimizer
                if not any(nd in name for nd in no_decay)
            ],
            'weight_decay':
            float(args.weight_decay)
        }, {
            'params': [
                param for name, param in param_optimizer
                if any(nd in name for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=float(args.learning_rate),
                          eps=1e-6,
                          betas=(0.9, 0.98),
                          correct_bias=True)

        train_features = model.convert_dataset_to_features(train_dir, logger)

        if args.train_mode == 'sorted' or args.train_mode == 'random_sorted':
            train_features = sorted(train_features,
                                    key=lambda f: np.sum(f.input_mask))
        else:
            random.shuffle(train_features)


        train_dataloader = \
            get_dataloader_and_tensors(train_features, local_config['train_batch_size'])
        train_batches = [batch for batch in train_dataloader]

        num_train_optimization_steps = \
            len(train_batches) // local_config['gradient_accumulation_steps'] * \
                local_config['num_train_epochs']

        warmup_steps = int(args.warmup_proportion *
                           num_train_optimization_steps)
        if local_config['lr_scheduler'] == 'linear_warmup':
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=num_train_optimization_steps)
        elif local_config['lr_scheduler'] == 'constant_warmup':
            scheduler = get_constant_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps)
        logger.info("***** Training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Batch size = %d", local_config['train_batch_size'])
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if local_config['do_validation']:
            dev_features = model.convert_dataset_to_features(dev_dir, logger)
            logger.info("***** Dev *****")
            logger.info("  Num examples = %d", len(dev_features))
            logger.info("  Batch size = %d", local_config['eval_batch_size'])
            dev_dataloader = \
                get_dataloader_and_tensors(dev_features, local_config['eval_batch_size'])
            test_dir = os.path.join(local_config['data_dir'], 'test/')
            if os.path.exists(test_dir):
                test_features = model.convert_dataset_to_features(
                    test_dir, test_logger)
                logger.info("***** Test *****")
                logger.info("  Num examples = %d", len(test_features))
                logger.info("  Batch size = %d",
                            local_config['eval_batch_size'])

                test_dataloader = \
                    get_dataloader_and_tensors(test_features, local_config['eval_batch_size'])

        best_result = defaultdict(float)

        eval_step = max(1, len(train_batches) // args.eval_per_epoch)

        start_time = time.time()
        global_step = 0

        model.to(device)
        lr = float(args.learning_rate)
        for epoch in range(1, 1 + local_config['num_train_epochs']):
            tr_loss = 0
            nb_tr_examples = 0
            nb_tr_steps = 0
            cur_train_loss = defaultdict(float)

            model.train()
            logger.info("Start epoch #{} (lr = {})...".format(
                epoch,
                scheduler.get_lr()[0]))
            if args.train_mode == 'random' or args.train_mode == 'random_sorted':
                random.shuffle(train_batches)

            train_bar = tqdm(train_batches,
                             total=len(train_batches),
                             desc='training ... ')
            for step, batch in enumerate(train_bar):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, token_type_ids, \
                syn_labels, positions = batch
                train_loss, _ = model(input_ids=input_ids,
                                      token_type_ids=token_type_ids,
                                      attention_mask=input_mask,
                                      input_labels={
                                          'syn_labels': syn_labels,
                                          'positions': positions
                                      })
                loss = train_loss['total'].mean().item()
                for key in train_loss:
                    cur_train_loss[key] += train_loss[key].mean().item()

                train_bar.set_description(
                    f'training... [epoch == {epoch} / {local_config["num_train_epochs"]}, loss == {loss}]'
                )

                loss_to_optimize = train_loss['total']

                if local_config['gradient_accumulation_steps'] > 1:
                    loss_to_optimize = \
                        loss_to_optimize / local_config['gradient_accumulation_steps']

                loss_to_optimize.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                tr_loss += loss_to_optimize.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                if (step +
                        1) % local_config['gradient_accumulation_steps'] == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                if local_config['do_validation'] and (step +
                                                      1) % eval_step == 0:
                    logger.info(
                        'Ep: {}, Stp: {}/{}, usd_t={:.2f}s, loss={:.6f}'.
                        format(epoch, step + 1, len(train_batches),
                               time.time() - start_time,
                               tr_loss / nb_tr_steps))

                    cur_train_mean_loss = {}
                    for key, value in cur_train_loss.items():
                        cur_train_mean_loss[f'train_{key}_loss'] = \
                            value / nb_tr_steps

                    dev_predictions = os.path.join(args.output_dir,
                                                   'dev_predictions')

                    metrics = predict(model,
                                      dev_dataloader,
                                      dev_predictions,
                                      dev_features,
                                      args,
                                      cur_train_mean_loss=cur_train_mean_loss,
                                      logger=eval_logger)

                    metrics['global_step'] = global_step
                    metrics['epoch'] = epoch
                    metrics['learning_rate'] = scheduler.get_lr()[0]
                    metrics['batch_size'] = \
                        local_config['train_batch_size'] * local_config['gradient_accumulation_steps']

                    for key, value in metrics.items():
                        dev_writer.add_scalar(key, value, global_step)
                    scores_to_logger = tuple([
                        round(metrics[save_by_score] * 100.0, 2)
                        for save_by_score in args.save_by_score.split('+')
                    ])
                    logger.info(
                        "dev %s (lr=%s, epoch=%d): %s" %
                        (args.save_by_score, str(
                            scheduler.get_lr()[0]), epoch, scores_to_logger))

                    predict_parts = [
                        part for part in metrics if part.endswith('.score')
                        and metrics[part] > args.start_save_threshold
                        and metrics[part] > best_result[part]
                    ]
                    if len(predict_parts) > 0:
                        best_dev_predictions = os.path.join(
                            args.output_dir, 'best_dev_predictions')
                        dev_predictions = os.path.join(args.output_dir,
                                                       'dev_predictions')
                        os.makedirs(best_dev_predictions, exist_ok=True)
                        for part in predict_parts:
                            logger.info(
                                "!!! Best dev %s (lr=%s, epoch=%d): %.2f -> %.2f"
                                % (part, str(scheduler.get_lr()[0]), epoch,
                                   best_result[part] * 100.0,
                                   metrics[part] * 100.0))
                            best_result[part] = metrics[part]
                            if [
                                    save_weight for save_weight in
                                    args.save_by_score.split('+')
                                    if save_weight == part
                            ]:
                                os.makedirs(os.path.join(
                                    args.output_dir, part),
                                            exist_ok=True)
                                output_model_file = os.path.join(
                                    args.output_dir, part, WEIGHTS_NAME)
                                save_model(args, model, output_model_file,
                                           metrics)
                            if 'nen-nen' not in part:
                                os.system(
                                    f'cp {dev_predictions}/{".".join(part.split(".")[1:-1])}* {best_dev_predictions}/'
                                )
                            else:
                                output_model_file = os.path.join(
                                    args.output_dir, 'nen-nen-weights',
                                    WEIGHTS_NAME)
                                save_model(args, model, output_model_file,
                                           metrics)

                        # dev_predictions = os.path.join(args.output_dir, 'dev_predictions')
                        # predict(
                        #     model, dev_dataloader, dev_predictions,
                        #     dev_features, args, only_parts='+'.join(predict_parts)
                        # )
                        # best_dev_predictions = os.path.join(args.output_dir, 'best_dev_predictions')
                        # os.makedirs(best_dev_predictions, exist_ok=True)
                        # os.system(f'mv {dev_predictions}/* {best_dev_predictions}/')
                        if 'scd' not in '+'.join(
                                predict_parts) and os.path.exists(test_dir):
                            test_predictions = os.path.join(
                                args.output_dir, 'test_predictions')
                            test_metrics = predict(
                                model,
                                test_dataloader,
                                test_predictions,
                                test_features,
                                args,
                                only_parts='+'.join([
                                    'test' + part[3:] for part in predict_parts
                                    if 'nen-nen' not in part
                                ]))
                            best_test_predictions = os.path.join(
                                args.output_dir, 'best_test_predictions')
                            os.makedirs(best_test_predictions, exist_ok=True)
                            os.system(
                                f'mv {test_predictions}/* {best_test_predictions}/'
                            )

                            for key, value in test_metrics.items():
                                if key.endswith('score'):
                                    dev_writer.add_scalar(
                                        key, value, global_step)

            if args.log_train_metrics:
                metrics = predict(model,
                                  train_dataloader,
                                  os.path.join(args.output_dir,
                                               'train_predictions'),
                                  train_features,
                                  args,
                                  logger=logger)
                metrics['global_step'] = global_step
                metrics['epoch'] = epoch
                metrics['learning_rate'] = scheduler.get_lr()[0]
                metrics['batch_size'] = \
                    local_config['train_batch_size'] * local_config['gradient_accumulation_steps']

                for key, value in metrics.items():
                    train_writer.add_scalar(key, value, global_step)

    if local_config['do_eval']:
        assert args.ckpt_path != '', 'in do_eval mode ckpt_path should be specified'
        test_dir = args.eval_input_dir
        config = configs[model_name].from_pretrained(model_name)
        model = models[model_name].from_pretrained(
            args.ckpt_path,
            local_config=local_config,
            data_processor=data_processor,
            config=config)
        model.to(device)
        test_features = model.convert_dataset_to_features(
            test_dir, test_logger)
        logger.info("***** Test *****")
        logger.info("  Num examples = %d", len(test_features))
        logger.info("  Batch size = %d", local_config['eval_batch_size'])

        test_dataloader = \
            get_dataloader_and_tensors(test_features, local_config['eval_batch_size'])

        metrics = predict(model,
                          test_dataloader,
                          os.path.join(args.output_dir, args.eval_output_dir),
                          test_features,
                          args,
                          compute_metrics=True)
        print(metrics)
        with open(
                os.path.join(args.output_dir, args.eval_output_dir,
                             'metrics.txt'), 'w') as outp:
            print(metrics, file=outp)
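The scheduler selection above switches between a linear-decay and a constant schedule, both with warmup. A condensed sketch of that setup, assuming the same transformers AdamW and scheduler helpers imported by the snippet (build_optimizer_and_scheduler is an illustrative name):

from transformers import (AdamW, get_constant_schedule_with_warmup,
                          get_linear_schedule_with_warmup)

def build_optimizer_and_scheduler(model, num_training_steps, lr=1e-5,
                                  warmup_proportion=0.1,
                                  lr_scheduler='linear_warmup'):
    # AdamW with the same betas/eps used above; weight-decay grouping omitted for brevity
    optimizer = AdamW(model.parameters(), lr=lr, eps=1e-6,
                      betas=(0.9, 0.98), correct_bias=True)
    warmup_steps = int(warmup_proportion * num_training_steps)
    if lr_scheduler == 'linear_warmup':
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps)
    else:  # 'constant_warmup'
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps)
    return optimizer, scheduler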
Beispiel #19
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument(
        "--data_label",
        default="",
        type=str,
        help="Optional label identifying the dataset split to use")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    train_examples, _ = get_FEVER_examples('train', hypo_only=False)
    dev_and_test_examples, _ = get_FEVER_examples('dev', hypo_only=False)
    random.shuffle(dev_and_test_examples)
    dev_examples = dev_and_test_examples[:-10000]
    test_examples = dev_and_test_examples[-10000:]
    '''write into files'''
    def examples_2_file(exs, prefix):
        writefile = codecs.open(
            '/export/home/Dataset/para_entail_datasets/nli_FEVER/nli_fever/my_split_binary/'
            + prefix + '.txt', 'w', 'utf-8')
        for ex in exs:
            writefile.write(ex.label + '\t' + ex.text_a + '\t' + ex.text_b +
                            '\n')
        print('print over')
        writefile.close()

    examples_2_file(train_examples, 'train')
    examples_2_file(dev_examples, 'dev')
    examples_2_file(test_examples, 'test')
    exit(0)

    label_list = ["entailment", "not_entailment"]  #, "contradiction"]
    num_labels = len(label_list)
    print('num_labels:', num_labels, 'training size:', len(train_examples),
          'dev size:', len(dev_examples), ' test size:', len(test_examples))

    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    model = RobertaForSequenceClassification(num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.load_state_dict(
        torch.load(
            '/export/home/Dataset/BERT_pretrained_mine/paragraph_entail/2021/ANLI_CNNDailyMail_DUC_Curation_SQUAD_epoch_1.pt',
            map_location=device))
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)
        '''load dev set'''
        dev_features = convert_examples_to_features(
            dev_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                         dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                          dtype=torch.long)
        dev_all_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features],
                                         dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask,
                                 dev_all_segment_ids, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data,
                                    sampler=dev_sampler,
                                    batch_size=args.eval_batch_size)
        '''load test set'''
        test_features = convert_examples_to_features(
            test_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        test_all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                          dtype=torch.long)
        test_all_input_mask = torch.tensor(
            [f.input_mask for f in test_features], dtype=torch.long)
        test_all_segment_ids = torch.tensor(
            [f.segment_ids for f in test_features], dtype=torch.long)
        test_all_label_ids = torch.tensor([f.label_id for f in test_features],
                                          dtype=torch.long)

        test_data = TensorDataset(test_all_input_ids, test_all_input_mask,
                                  test_all_segment_ids, test_all_label_ids)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        iter_co = 0
        final_test_performance = 0.0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                logits = model(input_ids, input_mask)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, num_labels),
                                label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                iter_co += 1
            '''
            evaluate on the dev set after this epoch
            '''
            model.eval()

            dev_acc = evaluation(dev_dataloader, device, model)

            if dev_acc > max_dev_acc:
                max_dev_acc = dev_acc
                print('\ndev acc:', dev_acc, ' max_dev_acc:', max_dev_acc,
                      '\n')
                '''evaluate on the test set with the best dev model'''
                final_test_performance = evaluation(test_dataloader, device,
                                                    model)
                print('\ntest acc:', final_test_performance, '\n')

            else:
                print('\ndev acc:', dev_acc, ' max_dev_acc:', max_dev_acc,
                      '\n')
        print('final_test_performance:', final_test_performance)
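
# The evaluation() helper used above is not part of this snippet. A minimal
# sketch of what it might look like, assuming each batch is
# (input_ids, input_mask, segment_ids, label_ids) and the model returns
# classification logits when called with (input_ids, input_mask):
def evaluation(dataloader, device, model):
    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for input_ids, input_mask, segment_ids, label_ids in dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            logits = model(input_ids, input_mask)
            preds = torch.argmax(logits, dim=-1).cpu()
            correct += (preds == label_ids).sum().item()
            total += label_ids.size(0)
    return correct / max(total, 1)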
Beispiel #20
0
class TRADE(nn.Module):
    def __init__(self,
                 hidden_size,
                 lang,
                 path,
                 task,
                 lr,
                 dropout,
                 slots,
                 gating_dict,
                 t_total,
                 device,
                 nb_train_vocab=0):
        super(TRADE, self).__init__()
        self.name = "TRADE"
        self.task = task
        self.hidden_size = hidden_size
        self.lang = lang[0]
        self.mem_lang = lang[1]
        self.lr = lr
        self.dropout = dropout
        self.slots = slots[0]
        self.slot_temp = slots[2]
        self.gating_dict = gating_dict
        self.device = device
        self.nb_gate = len(gating_dict)
        self.cross_entropy = nn.CrossEntropyLoss()
        self.cell_type = args['cell_type']

        if args['encoder'] == 'RNN':
            self.encoder = EncoderRNN(self.lang.n_words, hidden_size,
                                      self.dropout, self.device,
                                      self.cell_type)
            self.decoder = Generator(self.lang, self.encoder.embedding,
                                     self.lang.n_words, hidden_size,
                                     self.dropout, self.slots, self.nb_gate,
                                     self.device, self.cell_type)
        elif args['encoder'] == 'TPRNN':
            self.encoder = EncoderTPRNN(self.lang.n_words, hidden_size,
                                        self.dropout, self.device,
                                        self.cell_type, args['nSymbols'],
                                        args['nRoles'], args['dSymbols'],
                                        args['dRoles'], args['temperature'],
                                        args['scale_val'], args['train_scale'])
            self.decoder = Generator(self.lang, self.encoder.embedding,
                                     self.lang.n_words, hidden_size,
                                     self.dropout, self.slots, self.nb_gate,
                                     self.device, self.cell_type)
        else:
            self.encoder = BERTEncoder(hidden_size, self.dropout, self.device)
            self.decoder = Generator(self.lang, None, self.lang.n_words,
                                     hidden_size, self.dropout, self.slots,
                                     self.nb_gate, self.device, self.cell_type)

        if path:
            print("MODEL {} LOADED".format(str(path)))
            trained_encoder = torch.load(str(path) + '/enc.th',
                                         map_location=self.device)
            trained_decoder = torch.load(str(path) + '/dec.th',
                                         map_location=self.device)

            # fix small confusion between old and newer trained models
            encoder_dict = trained_encoder.state_dict()
            new_encoder_dict = {}
            for key in encoder_dict:
                mapped_key = key
                if key.startswith('gru.'):
                    mapped_key = 'rnn.' + key[len('gru.'):]
                new_encoder_dict[mapped_key] = encoder_dict[key]

            decoder_dict = trained_decoder.state_dict()
            new_decoder_dict = {}
            for key in decoder_dict:
                mapped_key = key
                if key.startswith('gru.'):
                    mapped_key = 'rnn.' + key[len('gru.'):]
                new_decoder_dict[mapped_key] = decoder_dict[key]

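            # Older checkpoints may not contain the slot-embedding projection;
            # zero-initialize the missing weights so load_state_dict() succeeds.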
            if 'W_slot_embed.weight' not in new_decoder_dict:
                new_decoder_dict['W_slot_embed.weight'] = torch.zeros(
                    (hidden_size, 2 * hidden_size), requires_grad=False)
                new_decoder_dict['W_slot_embed.bias'] = torch.zeros(
                    (hidden_size, ), requires_grad=False)

            self.encoder.load_state_dict(new_encoder_dict)
            self.decoder.load_state_dict(new_decoder_dict)

        # Initialize optimizers and criterion
        if args['encoder'] == 'RNN':
            self.optimizer = optim.Adam(self.parameters(), lr=lr)
            self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                            mode='max',
                                                            factor=0.5,
                                                            patience=1,
                                                            min_lr=0.0001,
                                                            verbose=True)
        else:
            if args['local_rank'] != -1:
                t_total = t_total // torch.distributed.get_world_size()

            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in self.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in self.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]
            self.optimizer = AdamW(optimizer_grouped_parameters,
                                   lr=args['learn'],
                                   correct_bias=False)
            self.scheduler = WarmupLinearSchedule(
                self.optimizer,
                warmup_steps=args['warmup_proportion'] * t_total,
                t_total=t_total)

        self.reset()

    def print_loss(self):
        print_loss_avg = self.loss / self.print_every
        print_loss_ptr = self.loss_ptr / self.print_every
        print_loss_gate = self.loss_gate / self.print_every
        print_loss_class = self.loss_class / self.print_every
        # print_loss_domain = self.loss_domain / self.print_every
        self.print_every += 1
        return 'L:{:.2f},LP:{:.2f},LG:{:.2f}'.format(print_loss_avg,
                                                     print_loss_ptr,
                                                     print_loss_gate)

    def save_model(self, dec_type):
        directory = 'save/TRADE-' + args["addName"] + args['dataset'] + str(
            self.task) + '/' + 'HDD' + str(self.hidden_size) + 'BSZ' + str(
                args['batch']) + 'DR' + str(self.dropout) + str(dec_type)
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(self.encoder, directory + '/enc.th')
        torch.save(self.decoder, directory + '/dec.th')

    def reset(self):
        self.loss, self.print_every, self.loss_ptr, self.loss_gate, self.loss_class = 0, 1, 0, 0, 0

    def forward(self, data, clip, slot_temp, reset=0, n_gpu=0):
        if reset: self.reset()
        # Zero gradients of both optimizers
        self.optimizer.zero_grad()

        # Encode and Decode
        use_teacher_forcing = random.random() < args["teacher_forcing_ratio"]
        all_point_outputs, gates, words_point_out, words_class_out = self.encode_and_decode(
            data, use_teacher_forcing, slot_temp)

        loss_ptr = masked_cross_entropy_for_value(
            all_point_outputs.transpose(0, 1).contiguous(),
            data["generate_y"].contiguous(
            ),  #[:,:len(self.point_slots)].contiguous(),
            data["y_lengths"])  #[:,:len(self.point_slots)])
        loss_gate = self.cross_entropy(
            gates.transpose(0, 1).contiguous().view(-1, gates.size(-1)),
            data["gating_label"].contiguous().view(-1))

        if args["use_gate"]:
            loss = loss_ptr + loss_gate
        else:
            loss = loss_ptr

        self.loss_grad = loss
        self.loss_ptr_to_bp = loss_ptr

        # Update parameters with optimizers
        self.loss += loss.item()
        self.loss_ptr += loss_ptr.item()
        self.loss_gate += loss_gate.item()

        return self.loss_grad

    def optimize_GEM(self, clip):
        torch.nn.utils.clip_grad_norm_(self.parameters(), clip)
        self.optimizer.step()
        if isinstance(self.scheduler, WarmupLinearSchedule):
            self.scheduler.step()

    def encode_and_decode(self, data, use_teacher_forcing, slot_temp):
        if args['encoder'] == 'RNN' or args['encoder'] == 'TPRNN':
            # Build unknown mask for memory to encourage generalization
            if args['unk_mask'] and self.decoder.training:
                story_size = data['context'].size()
                rand_mask = np.ones(story_size)
                bi_mask = np.random.binomial(
                    [np.ones(
                        (story_size[0], story_size[1]))], 1 - self.dropout)[0]
                rand_mask = rand_mask * bi_mask
                rand_mask = torch.Tensor(rand_mask).to(self.device)
                story = data['context'] * rand_mask.long()
            else:
                story = data['context']

            story = story.to(self.device)
            # encoded_outputs, encoded_hidden = self.encoder(story.transpose(0, 1), data['context_len'])
            encoded_outputs, encoded_hidden = self.encoder(
                story, data['context_len'])

        # Encode dialog history
        # story  32 396
        # data['context_len'] 32
        elif args['encoder'] == 'BERT':
            # import pdb; pdb.set_trace()
            story = data['context']
            # story_plain = data['context_plain']

            all_input_ids = data['all_input_ids']
            all_input_mask = data['all_input_mask']
            all_segment_ids = data['all_segment_ids']
            all_sub_word_masks = data['all_sub_word_masks']

            encoded_outputs, encoded_hidden = self.encoder(
                all_input_ids, all_input_mask, all_segment_ids,
                all_sub_word_masks)
            encoded_hidden = encoded_hidden.unsqueeze(0)

        # Get the words that can be copied from the memory
        # import pdb; pdb.set_trace()
        batch_size = len(data['context_len'])
        self.copy_list = data['context_plain']
        max_res_len = data['generate_y'].size(
            2) if self.encoder.training else 10

        all_point_outputs, all_gate_outputs, words_point_out, words_class_out = self.decoder.forward(batch_size, \
            encoded_hidden, encoded_outputs, data['context_len'], story, max_res_len, data['generate_y'], \
            use_teacher_forcing, slot_temp)

        return all_point_outputs, all_gate_outputs, words_point_out, words_class_out

    def evaluate(self,
                 dev,
                 matric_best,
                 slot_temp,
                 device,
                 save_dir="",
                 save_string="",
                 early_stop=None):
        # Set to not-training mode to disable dropout
        self.encoder.train(False)
        self.decoder.train(False)
        print("STARTING EVALUATION")
        all_prediction = {}
        inverse_unpoint_slot = dict([(v, k)
                                     for k, v in self.gating_dict.items()])
        pbar = enumerate(dev)
        for j, data_dev in pbar:
            # Encode and Decode
            eval_data = {}
            # wrap all numerical values as tensors for multi-gpu training
            for k, v in data_dev.items():
                if isinstance(v, torch.Tensor):
                    eval_data[k] = v.to(device)
                elif isinstance(v, list):
                    if k in [
                            'ID', 'turn_belief', 'context_plain',
                            'turn_uttr_plain'
                    ]:
                        eval_data[k] = v
                    else:
                        eval_data[k] = torch.tensor(v).to(device)
                else:
                    # print('v is: {} and this ignoring {}'.format(v, k))
                    pass
            batch_size = len(data_dev['context_len'])
            with torch.no_grad():
                _, gates, words, class_words = self.encode_and_decode(
                    eval_data, False, slot_temp)

            for bi in range(batch_size):
                if data_dev["ID"][bi] not in all_prediction.keys():
                    all_prediction[data_dev["ID"][bi]] = {}
                all_prediction[data_dev["ID"][bi]][data_dev["turn_id"][bi]] = {
                    "turn_belief": data_dev["turn_belief"][bi]
                }
                predict_belief_bsz_ptr, predict_belief_bsz_class = [], []
                gate = torch.argmax(gates.transpose(0, 1)[bi], dim=1)
                # import pdb; pdb.set_trace()

                # pointer-generator results
                if args["use_gate"]:
                    for si, sg in enumerate(gate):
                        if sg == self.gating_dict["none"]:
                            continue
                        elif sg == self.gating_dict["ptr"]:
                            pred = np.transpose(words[si])[bi]
                            st = []
                            for e in pred:
                                if e == 'EOS': break
                                else: st.append(e)
                            st = " ".join(st)
                            if st == "none":
                                continue
                            else:
                                predict_belief_bsz_ptr.append(slot_temp[si] +
                                                              "-" + str(st))
                        else:
                            predict_belief_bsz_ptr.append(
                                slot_temp[si] + "-" +
                                inverse_unpoint_slot[sg.item()])
                else:
                    for si, _ in enumerate(gate):
                        pred = np.transpose(words[si])[bi]
                        st = []
                        for e in pred:
                            if e == 'EOS': break
                            else: st.append(e)
                        st = " ".join(st)
                        if st == "none":
                            continue
                        else:
                            predict_belief_bsz_ptr.append(slot_temp[si] + "-" +
                                                          str(st))

                all_prediction[data_dev["ID"][bi]][data_dev["turn_id"][bi]][
                    "pred_bs_ptr"] = predict_belief_bsz_ptr

                #if set(data_dev["turn_belief"][bi]) != set(predict_belief_bsz_ptr) and args["genSample"]:
                #    print("True", set(data_dev["turn_belief"][bi]) )
                #    print("Pred", set(predict_belief_bsz_ptr), "\n")

        if args["genSample"]:
            if save_dir != "" and not os.path.exists(save_dir):
                os.mkdir(save_dir)
            json.dump(all_prediction,
                      open(
                          os.path.join(
                              save_dir, "prediction_{}_{}.json".format(
                                  self.name, save_string)), 'w'),
                      indent=4)
            print(
                "saved generated samples",
                os.path.join(
                    save_dir,
                    "prediction_{}_{}.json".format(self.name, save_string)))

        joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = self.evaluate_metrics(
            all_prediction, "pred_bs_ptr", slot_temp)

        evaluation_metrics = {
            "Joint Acc": joint_acc_score_ptr,
            "Turn Acc": turn_acc_score_ptr,
            "Joint F1": F1_score_ptr
        }
        print(evaluation_metrics)

        # Set back to training mode
        self.encoder.train(True)
        self.decoder.train(True)

        joint_acc_score = joint_acc_score_ptr  # (joint_acc_score_ptr + joint_acc_score_class)/2
        F1_score = F1_score_ptr

        if (early_stop == 'F1'):
            if (F1_score >= matric_best):
                self.save_model('ENTF1-{:.4f}'.format(F1_score))
                print("MODEL SAVED")
            return F1_score
        else:
            if (joint_acc_score >= matric_best):
                self.save_model('ACC-{:.4f}'.format(joint_acc_score))
                print("MODEL SAVED")
            return joint_acc_score

    def evaluate_metrics(self, all_prediction, from_which, slot_temp):
        total, turn_acc, joint_acc, F1_pred, F1_count = 0, 0, 0, 0, 0
        for d, v in all_prediction.items():
            for t in range(len(v)):
                cv = v[t]
                if set(cv["turn_belief"]) == set(cv[from_which]):
                    joint_acc += 1
                total += 1

                # Compute prediction slot accuracy
                temp_acc = self.compute_acc(set(cv["turn_belief"]),
                                            set(cv[from_which]), slot_temp)
                turn_acc += temp_acc

                # Compute prediction joint F1 score
                temp_f1, temp_r, temp_p, count = self.compute_prf(
                    set(cv["turn_belief"]), set(cv[from_which]))
                F1_pred += temp_f1
                F1_count += count

        joint_acc_score = joint_acc / float(total) if total != 0 else 0
        turn_acc_score = turn_acc / float(total) if total != 0 else 0
        F1_score = F1_pred / float(F1_count) if F1_count != 0 else 0
        return joint_acc_score, F1_score, turn_acc_score

    def compute_acc(self, gold, pred, slot_temp):
        miss_gold = 0
        miss_slot = []
        for g in gold:
            if g not in pred:
                miss_gold += 1
                miss_slot.append(g.rsplit("-", 1)[0])
        wrong_pred = 0
        for p in pred:
            if p not in gold and p.rsplit("-", 1)[0] not in miss_slot:
                wrong_pred += 1
        ACC_TOTAL = len(slot_temp)
        ACC = len(slot_temp) - miss_gold - wrong_pred
        ACC = ACC / float(ACC_TOTAL)
        return ACC

    def compute_prf(self, gold, pred):
        TP, FP, FN = 0, 0, 0
        if len(gold) != 0:
            count = 1
            for g in gold:
                if g in pred:
                    TP += 1
                else:
                    FN += 1
            for p in pred:
                if p not in gold:
                    FP += 1
            precision = TP / float(TP + FP) if (TP + FP) != 0 else 0
            recall = TP / float(TP + FN) if (TP + FN) != 0 else 0
            F1 = 2 * precision * recall / float(precision + recall) if (
                precision + recall) != 0 else 0
        else:
            if len(pred) == 0:
                precision, recall, F1, count = 1, 1, 1, 1
            else:
                precision, recall, F1, count = 0, 0, 0, 1
        return F1, recall, precision, count
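
# A minimal training-step sketch for the TRADE model above (the loader and
# clip names are assumptions for illustration): forward() zeroes the gradients,
# accumulates the losses and returns the loss tensor, while optimize_GEM()
# applies gradient clipping and the optimizer/scheduler step.
def train_trade_epoch(model, train_loader, slot_temp, clip=10.0):
    for i, data in enumerate(train_loader):
        loss = model(data, clip, slot_temp, reset=(i == 0))
        loss.backward()
        model.optimize_GEM(clip)
    print(model.print_loss())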
Beispiel #21
0
def trainBERT(model, train_loader, val_loader, num_epoch=5, lr=2e-2):
    # Training steps
    start_time = time.time()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=lr, eps=1e-8)

    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    auc = []
    best_auc = 0.
    best_model = copy.deepcopy(model.state_dict())

    for epoch in range(num_epoch):
        model.train()
        #Initialize
        correct = 0
        total = 0
        total_loss = 0

        for i, (data, mask, labels) in enumerate(train_loader):
            data, mask, labels = data.to(device), mask.to(device), labels.to(
                device, dtype=torch.long)
            optimizer.zero_grad()

            outputs = model(data,
                            token_type_ids=None,
                            attention_mask=mask,
                            labels=None)

            loss = loss_fn(outputs.view(-1, 2), labels.view(-1))

            loss.backward()
            optimizer.step()
            label_cpu = labels.squeeze().to('cpu').numpy()
            pred = outputs.data.max(-1)[1].to('cpu').numpy()
            total += labels.size(0)
            correct += float(sum((pred == label_cpu)))
            total_loss += loss.item()

        acc = correct / total

        t_loss = total_loss / total
        train_loss.append(t_loss)
        train_acc.append(acc)
        # report performance

        print('Epoch: ', epoch)
        print('Train set | Accuracy: {:6.4f} | Loss: {:6.4f}'.format(
            acc, t_loss))

        # Evaluate after every epoch
        #Reset the initialization
        correct = 0
        total = 0
        total_loss = 0
        model.eval()

        predictions = []
        truths = []

        with torch.no_grad():
            for i, (data, mask, labels) in enumerate(val_loader):
                data, mask, labels = data.to(device), mask.to(
                    device), labels.to(device, dtype=torch.long)

                optimizer.zero_grad()

                outputs = model(data,
                                token_type_ids=None,
                                attention_mask=mask,
                                labels=None)
                #va_loss = loss_fn(outputs.squeeze(-1), labels)
                va_loss = loss_fn(outputs.view(-1, 2), labels.view(-1))

                label_cpu = labels.squeeze().to('cpu').numpy()

                pred = outputs.data.max(-1)[1].to('cpu').numpy()
                total += labels.size(0)
                correct += float(sum((pred == label_cpu)))
                total_loss += va_loss.item()

                predictions += list(pred)
                truths += list(label_cpu)

            v_acc = correct / total
            v_loss = total_loss / total
            val_loss.append(v_loss)
            val_acc.append(v_acc)

            v_auc = roc_auc_score(truths, predictions)
            auc.append(v_auc)

            elapse = time.strftime(
                '%H:%M:%S', time.gmtime(int((time.time() - start_time))))
            print(
                'Validation set | Accuracy: {:6.4f} | AUC: {:6.4f} | Loss: {:4.2f} | time elapse: {:>9}'
                .format(v_acc, v_auc, v_loss, elapse))
            print('-' * 10)

            if v_auc > best_auc:
                best_auc = v_auc
                best_model = copy.deepcopy(model.state_dict())

    print('Best validation auc: {:6.4f}'.format(best_auc))
    model.load_state_dict(best_model)
    return train_loss, train_acc, val_loss, val_acc, v_auc, model
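
# A usage sketch for trainBERT() above (dataset and model names here are
# hypothetical). It assumes the model returns raw logits of shape [batch, 2]
# when called with labels=None, and that the global `device` used inside
# trainBERT() is defined:
def run_trainBERT(model, train_dataset, val_dataset):
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64)
    history = trainBERT(model, train_loader, val_loader, num_epoch=3, lr=2e-5)
    train_loss, train_acc, val_loss, val_acc, last_auc, best_model = history
    return best_model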
Beispiel #22
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--round_name",
        default="",
        type=str,
        help=
        "Which incremental round to run: one of base, r1, r2, r3, r4, r5")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    round_name_2_rounds = {
        'base': ['base', 'ood'],
        'r1': ['base', 'n1', 'ood'],
        'r2': ['base', 'n1', 'n2', 'ood'],
        'r3': ['base', 'n1', 'n2', 'n3', 'ood'],
        'r4': ['base', 'n1', 'n2', 'n3', 'n4', 'ood'],
        'r5': ['base', 'n1', 'n2', 'n3', 'n4', 'n5', 'ood']
    }

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    banking77_class_list, ood_class_set, class_2_split = load_class_names()

    round_list = round_name_2_rounds.get(args.round_name)
    train_examples, base_class_list = processor.load_train(
        ['base'])  #train on base only
    '''train the first stage'''
    model = RobertaForSequenceClassification(len(base_class_list))
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    train_dataloader = examples_to_features(train_examples,
                                            base_class_list,
                                            args,
                                            tokenizer,
                                            args.train_batch_size,
                                            "classification",
                                            dataloader_mode='random')
    mean_loss = 0.0
    count = 0
    for _ in trange(int(args.num_train_epochs), desc="Stage1Epoch"):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch

            logits = model(input_ids, input_mask, output_rep=False)
            loss_fct = CrossEntropyLoss()

            loss = loss_fct(logits.view(-1, len(base_class_list)),
                            label_ids.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            mean_loss += loss.item()
            count += 1
            # if count % 50 == 0:
            #     print('mean loss:', mean_loss/count)
    print('stage 1, train supervised classification on base is over.')
    '''now, train the second stage'''
    model_stage_2 = ModelStageTwo(len(base_class_list), model)
    model_stage_2.to(device)

    param_optimizer = list(model_stage_2.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer_stage_2 = AdamW(optimizer_grouped_parameters,
                              lr=args.learning_rate)
    mean_loss = 0.0
    count = 0
    best_threshold = 0.0
    for _ in trange(int(args.num_train_epochs), desc="Stage2Epoch"):
        '''first, select some base classes as fake novel classes'''
        fake_novel_size = 15
        fake_novel_support_size = 5
        '''for convenience, shuffle the base classes and take the last fake_novel_size classes as the fake novel classes'''
        original_base_class_idlist = list(range(len(base_class_list)))
        # random.shuffle(original_base_class_idlist)
        shuffled_base_class_list = [
            base_class_list[idd] for idd in original_base_class_idlist
        ]
        fake_novel_classlist = shuffled_base_class_list[-fake_novel_size:]
        '''load their support examples'''
        base_support_examples = processor.load_base_support_examples(
            fake_novel_classlist, fake_novel_support_size)
        base_support_dataloader = examples_to_features(
            base_support_examples,
            fake_novel_classlist,
            args,
            tokenizer,
            fake_novel_support_size,
            "classification",
            dataloader_mode='sequential')

        novel_class_support_reps = []
        for _, batch in enumerate(base_support_dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            model.eval()
            with torch.no_grad():
                support_rep_for_novel_class = model(input_ids,
                                                    input_mask,
                                                    output_rep=True)
            novel_class_support_reps.append(support_rep_for_novel_class)
        assert len(novel_class_support_reps) == fake_novel_size
        print('Extracting support reps for fake novel is over.')
        '''retrain on query set to optimize the weight generator'''
        train_dataloader = examples_to_features(train_examples,
                                                shuffled_base_class_list,
                                                args,
                                                tokenizer,
                                                args.train_batch_size,
                                                "classification",
                                                dataloader_mode='random')
        best_threshold_list = []
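        # Collect the mean logit of the gold class on each query batch; the
        # average of these scores later serves as the threshold that flags
        # test inputs as out-of-distribution.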
        for _ in range(10):  # repeating 10 times is important
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model_stage_2.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                logits = model_stage_2(
                    input_ids,
                    input_mask,
                    model,
                    novel_class_support_reps=novel_class_support_reps,
                    fake_novel_size=fake_novel_size,
                    base_class_mapping=original_base_class_idlist)
                # print('logits:', logits)
                loss_fct = CrossEntropyLoss()

                loss = loss_fct(logits.view(-1, len(base_class_list)),
                                label_ids.view(-1))
                loss.backward()
                optimizer_stage_2.step()
                optimizer_stage_2.zero_grad()
                mean_loss += loss.item()
                count += 1
                if count % 50 == 0:
                    print('mean loss:', mean_loss / count)
                scores_for_positive = logits[torch.arange(logits.shape[0]),
                                             label_ids.view(-1)].mean()
                best_threshold_list.append(scores_for_positive.item())

        best_threshold = sum(best_threshold_list) / len(best_threshold_list)

    print('stage 2 training over')
    '''
    start testing
    '''
    '''first, get reps for all base+novel classes'''
    '''support for all seen classes'''
    class_2_support_examples, seen_class_list = processor.load_support_all_rounds(
        round_list[:-1])  #no support set for ood
    assert seen_class_list[:len(base_class_list)] == base_class_list
    seen_class_list_size = len(seen_class_list)
    support_example_lists = [
        class_2_support_examples.get(seen_class)
        for seen_class in seen_class_list if seen_class not in base_class_list
    ]

    novel_class_support_reps = []
    for eval_support_examples_per_class in support_example_lists:
        support_dataloader = examples_to_features(
            eval_support_examples_per_class,
            seen_class_list,
            args,
            tokenizer,
            5,
            "classification",
            dataloader_mode='random')
        single_class_support_reps = []
        for _, batch in enumerate(support_dataloader):
            input_ids, input_mask, segment_ids, label_ids = batch
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            model.eval()
            with torch.no_grad():
                support_rep_for_novel_class = model(input_ids,
                                                    input_mask,
                                                    output_rep=True)
            single_class_support_reps.append(support_rep_for_novel_class)
        single_class_support_reps = torch.cat(single_class_support_reps,
                                              axis=0)
        novel_class_support_reps.append(single_class_support_reps)
    print('len(novel_class_support_reps):', len(novel_class_support_reps))
    print('len(base_class_list):', len(base_class_list))
    print('len(seen_class_list):', len(seen_class_list))
    assert len(novel_class_support_reps) + len(base_class_list) == len(
        seen_class_list)
    print('Extracting support reps for all novel classes is over.')
    test_examples = processor.load_dev_or_test(round_list, 'test')
    test_class_list = seen_class_list + list(ood_class_set)
    print('test_class_list:', len(test_class_list))
    print('best_threshold:', best_threshold)
    test_split_list = []
    for test_class_i in test_class_list:
        test_split_list.append(class_2_split.get(test_class_i))
    test_dataloader = examples_to_features(test_examples,
                                           test_class_list,
                                           args,
                                           tokenizer,
                                           args.eval_batch_size,
                                           "classification",
                                           dataloader_mode='sequential')
    '''test on test batch '''
    preds = []
    gold_label_ids = []
    for input_ids, input_mask, segment_ids, label_ids in test_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)
        gold_label_ids += list(label_ids.detach().cpu().numpy())
        model_stage_2.eval()
        with torch.no_grad():
            logits = model_stage_2(
                input_ids,
                input_mask,
                model,
                novel_class_support_reps=novel_class_support_reps,
                fake_novel_size=None,
                base_class_mapping=None)
        # print('test logits:', logits)
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(preds[0],
                                 logits.detach().cpu().numpy(),
                                 axis=0)

    preds = preds[0]

    pred_probs = preds  #softmax(preds,axis=1)
    pred_label_ids_raw = list(np.argmax(pred_probs, axis=1))
    pred_max_prob = list(np.amax(pred_probs, axis=1))

    pred_label_ids = []
    for i, pred_max_prob_i in enumerate(pred_max_prob):
        if pred_max_prob_i < best_threshold:
            pred_label_ids.append(
                seen_class_list_size)  #seen_class_list_size means ood
        else:
            pred_label_ids.append(pred_label_ids_raw[i])

    assert len(pred_label_ids) == len(gold_label_ids)
    acc_each_round = []
    for round_name_id in round_list:
        #base, n1, n2, ood
        round_size = 0
        round_hit = 0
        if round_name_id != 'ood':
            for ii, gold_label_id in enumerate(gold_label_ids):
                if test_split_list[gold_label_id] == round_name_id:
                    round_size += 1
                    # print('gold_label_id:', gold_label_id, 'pred_label_ids[ii]:', pred_label_ids[ii])
                    if gold_label_id == pred_label_ids[ii]:
                        round_hit += 1
            acc_i = round_hit / round_size
            acc_each_round.append(acc_i)
        else:
            '''ood f1'''
            gold_binary_list = []
            pred_binary_list = []
            for ii, gold_label_id in enumerate(gold_label_ids):
                # print('gold_label_id:', gold_label_id, 'pred_label_ids[ii]:', pred_label_ids[ii])
                gold_binary_list.append(1 if test_split_list[gold_label_id] ==
                                        round_name_id else 0)
                pred_binary_list.append(1 if pred_label_ids[ii] ==
                                        seen_class_list_size else 0)
            overlap = 0
            for i in range(len(gold_binary_list)):
                if gold_binary_list[i] == 1 and pred_binary_list[i] == 1:
                    overlap += 1
            recall = overlap / (1e-6 + sum(gold_binary_list))
            precision = overlap / (1e-6 + sum(pred_binary_list))

            acc_i = 2 * recall * precision / (1e-6 + recall + precision)
            acc_each_round.append(acc_i)

    print('\n\t\t test_acc:', acc_each_round)
    final_test_performance = acc_each_round

    print('final_test_performance:', final_test_performance)
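
# The weight-decay grouping above is written out twice (once per training
# stage). A small reusable sketch of the same pattern, not part of the
# original script:
def grouped_parameters(model, weight_decay=0.01,
                       no_decay=('bias', 'LayerNorm.bias', 'LayerNorm.weight')):
    named = list(model.named_parameters())
    return [
        {'params': [p for n, p in named if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in named if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

# e.g. optimizer = AdamW(grouped_parameters(model), lr=args.learning_rate)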
Beispiel #23
0
def train():
    # Check the configuration and fetch hyperparameters
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device:{} n_gpu:{}".format(device, n_gpu))
    seed = hyperparameters["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    max_seq_length = hyperparameters["max_sent_length"]
    gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
    num_epochs = hyperparameters["num_epoch"]
    train_batch_size = hyperparameters["train_batch_size"] // hyperparameters["gradient_accumulation_steps"]
    tokenizer = BertTokenizer.from_pretrained("bert-large-uncased", do_lower_case=True)
    model = BertForMultipleChoice.from_pretrained("bert-large-uncased")
    model.to(device)

    # Optimizer
    param_optimizer = list(model.named_parameters())

    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    # Load data
    train_examples = read_examples('../dataset/train_bert.txt')
    dev_examples = read_examples('../dataset/test_bert.txt')
    nTrain = len(train_examples)
    nDev = len(dev_examples)
    num_train_optimization_steps = int(nTrain / train_batch_size / gradient_accumulation_steps) * num_epochs
    optimizer = AdamW(optimizer_grouped_parameters, lr=hyperparameters["learning_rate"])
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * num_train_optimization_steps),
                                                num_training_steps=num_train_optimization_steps)

    global_step = 0
    train_features = convert_examples_to_features(train_examples, tokenizer, max_seq_length)
    dev_features = convert_examples_to_features(dev_examples, tokenizer, max_seq_length)
    train_dataloader = get_train_dataloader(train_features, train_batch_size)
    dev_dataloader = get_eval_dataloader(dev_features, hyperparameters["eval_batch_size"])
    print("Num of train features:{}".format(nTrain))
    print("Num of dev features:{}".format(nDev))
    best_dev_accuracy = 0
    best_dev_epoch = 0
    no_up = 0

    epoch_tqdm = trange(int(num_epochs), desc="Epoch")
    for epoch in epoch_tqdm:
        model.train()

        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            input_ids, label_ids = batch
            loss, logits = model(input_ids=input_ids, labels=label_ids)[:2]
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            loss.backward()
            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

        train_loss, train_accuracy = evaluate(model, device, train_dataloader, "Train")
        dev_loss, dev_accuracy = evaluate(model, device, dev_dataloader, "Dev")

        if dev_accuracy > best_dev_accuracy:
            best_dev_accuracy = dev_accuracy
            best_dev_epoch = epoch + 1
            no_up = 0

        else:
            no_up += 1
        tqdm.write("\t ***** Eval results (Epoch %s) *****" % str(epoch + 1))
        tqdm.write("\t train_accuracy = %s" % str(train_accuracy))
        tqdm.write("\t dev_accuracy = %s" % str(dev_accuracy))
        tqdm.write("")
        tqdm.write("\t best_dev_accuracy = %s" % str(best_dev_accuracy))
        tqdm.write("\t best_dev_epoch = %s" % str(best_dev_epoch))
        tqdm.write("\t no_up = %s" % str(no_up))
        tqdm.write("")
        if no_up >= hyperparameters["patience"]:
            epoch_tqdm.close()
            break
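
# The evaluate() helper used above is not shown in this example. A minimal
# sketch under the same assumptions: batches of (input_ids, label_ids) and a
# BertForMultipleChoice-style model that returns (loss, logits) when labels
# are passed:
def evaluate(model, device, dataloader, name):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for input_ids, label_ids in dataloader:
            input_ids, label_ids = input_ids.to(device), label_ids.to(device)
            loss, logits = model(input_ids=input_ids, labels=label_ids)[:2]
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=-1)
            correct += (preds == label_ids).sum().item()
            total += label_ids.size(0)
    avg_loss = total_loss / max(len(dataloader), 1)
    accuracy = correct / max(total, 1)
    print("{} | loss: {:.4f} | accuracy: {:.4f}".format(name, avg_loss, accuracy))
    return avg_loss, accuracy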
Beispiel #24
0
class Framework(object):
    """A framework wrapping the Relational Graph Extraction model. This framework allows to train, predict, evaluate,
    saving and loading the model with a single line of code.
    """
    def __init__(self, **config):
        super().__init__()

        self.config = config

        self.grad_acc = self.config.get('grad_acc', 1)
        self.device = torch.device(self.config['device'])
        if isinstance(self.config['model'], str):
            self.model = MODELS[self.config['model']](**self.config)
        else:
            self.model = self.config['model']

        self.class_weights = torch.tensor(self.config['class_weights']).float(
        ) if 'class_weights' in self.config else torch.ones(
            self.config['n_rel'])
        if 'lambda' in self.config:
            self.class_weights[0] = self.config['lambda']
        self.loss_fn = nn.CrossEntropyLoss(weight=self.class_weights.to(
            self.device),
                                           reduction='mean')
        if self.config['optimizer'] == 'SGD':
            self.optimizer = torch.optim.SGD(
                self.model.get_parameters(self.config.get('l2', .01)),
                lr=self.config['lr'],
                momentum=self.config.get('momentum', 0),
                nesterov=self.config.get('nesterov', False))
        elif self.config['optimizer'] == 'Adam':
            self.optimizer = torch.optim.Adam(self.model.get_parameters(
                self.config.get('l2', .01)),
                                              lr=self.config['lr'])
        elif self.config['optimizer'] == 'AdamW':
            self.optimizer = AdamW(self.model.get_parameters(
                self.config.get('l2', .01)),
                                   lr=self.config['lr'])
        else:
            raise Exception('The optimizer must be SGD, Adam or AdamW')

    def _train_step(self, dataset, epoch, scheduler=None):
        print("Training:")
        self.model.train()

        total_loss = 0
        predictions, labels, positions = [], [], []
        precision = recall = fscore = 0.0
        progress = tqdm(
            enumerate(dataset),
            desc=
            f"Epoch: {epoch} - Loss: {0.0} - P/R/F: {precision}/{recall}/{fscore}",
            total=len(dataset))
        for i, batch in progress:
            # uncompress the batch
            seq, mask, ent, label = batch
            seq = seq.to(self.device)
            mask = mask.to(self.device)
            ent = ent.to(self.device)
            label = label.to(self.device)

            #self.optimizer.zero_grad()
            output = self.model(seq, mask, ent)
            loss = self.loss_fn(output, label)
            total_loss += loss.item()

            if self.config['half']:
                with amp.scale_loss(loss, self.optimizer) as scale_loss:
                    scale_loss.backward()
            else:
                loss.backward()

            if (i + 1) % self.grad_acc == 0:
                if self.config.get('grad_clip', False):
                    if self.config['half']:
                        nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            self.config['grad_clip'])
                    else:
                        nn.utils.clip_grad_norm_(self.model.parameters(),
                                                 self.config['grad_clip'])

                self.optimizer.step()
                self.model.zero_grad()
                if scheduler:
                    scheduler.step()

            # Evaluate results
            pre, lab, pos = dataset.evaluate(
                i,
                output.detach().numpy() if self.config['device'] == 'cpu' else
                output.detach().cpu().numpy())

            predictions.extend(pre)
            labels.extend(lab)
            positions.extend(pos)

            if (i + 1) % 10 == 0:
                precision, recall, fscore, _ = precision_recall_fscore_support(
                    np.array(labels),
                    np.array(predictions),
                    average='micro',
                    labels=list(range(1, self.model.n_rel)))

            progress.set_description(
                f"Epoch: {epoch} - Loss: {total_loss/(i+1):.3f} - P/R/F: {precision:.2f}/{recall:.2f}/{fscore:.2f}"
            )

        # For last iteration
        #self.optimizer.step()
        #self.optimizer.zero_grad()

        predictions, labels = np.array(predictions), np.array(labels)
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels,
            predictions,
            average='micro',
            labels=list(range(1, self.model.n_rel)))
        print(
            f"Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro')
        print(
            f"[with NO-RELATION] Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )

        return total_loss / (i + 1)

    def _val_step(self, dataset, epoch):
        print("Validating:")
        self.model.eval()

        predictions, labels, positions = [], [], []
        total_loss = 0
        with torch.no_grad():
            progress = tqdm(enumerate(dataset),
                            desc=f"Epoch: {epoch} - Loss: {0.0}",
                            total=len(dataset))
            for i, batch in progress:
                # uncompress the batch
                seq, mask, ent, label = batch
                seq = seq.to(self.device)
                mask = mask.to(self.device)
                ent = ent.to(self.device)
                label = label.to(self.device)

                output = self.model(seq, mask, ent)
                loss = self.loss_fn(output, label)
                total_loss += loss.item()

                # Evaluate results
                pre, lab, pos = dataset.evaluate(
                    i,
                    output.detach().numpy() if self.config['device'] == 'cpu'
                    else output.detach().cpu().numpy())

                predictions.extend(pre)
                labels.extend(lab)
                positions.extend(pos)

                progress.set_description(
                    f"Epoch: {epoch} - Loss: {total_loss/(i+1):.3f}")

        predictions, labels = np.array(predictions), np.array(labels)
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels,
            predictions,
            average='micro',
            labels=list(range(1, self.model.n_rel)))
        print(
            f"Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )
        noprecision, norecall, nofscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro')
        print(
            f"[with NO-RELATION] Precision: {noprecision:.3f} - Recall: {norecall:.3f} - F-Score: {nofscore:.3f}"
        )

        return total_loss / (i + 1), precision, recall, fscore

    def _save_checkpoint(self, dataset, epoch, loss, val_loss):
        print(f"Saving checkpoint ({dataset.name}.pth) ...")
        PATH = os.path.join('checkpoints', f"{dataset.name}.pth")
        config_PATH = os.path.join('checkpoints',
                                   f"{dataset.name}_config.json")
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': loss,
                'val_loss': val_loss
            }, PATH)
        with open(config_PATH, 'wt') as f:
            json.dump(self.config, f)

    def _load_checkpoint(self, PATH: str, config_PATH: str):
        checkpoint = torch.load(PATH)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']

        with open(config_PATH, 'rt') as f:
            self.config = json.load(f)

        return epoch, loss

    def fit(self,
            dataset,
            validation=True,
            batch_size=1,
            patience=3,
            delta=0.):
        """ Fits the model to the given dataset.

        Usage:
        ```
        >>> rge = Framework(**config)
        >>> rge.fit(train_data)
        ```
        """
        self.model.to(self.device)
        train_data = dataset.get_train(batch_size)

        if self.config['half']:
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level='O2',
                keep_batchnorm_fp32=True)

        if self.config['linear_scheduler']:
            num_training_steps = int(
                len(train_data) // self.grad_acc * self.config['epochs'])
            scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config.get('warmup_steps', 0),
                num_training_steps=num_training_steps)
        else:
            scheduler = None

        early_stopping = EarlyStopping(patience, delta, self._save_checkpoint)

        for epoch in range(self.config['epochs']):
            self.optimizer.zero_grad()
            loss = self._train_step(train_data, epoch, scheduler=scheduler)
            if validation:
                val_loss, _, _, _ = self._val_step(dataset.get_val(batch_size),
                                                   epoch)
                if early_stopping(val_loss,
                                  dataset=dataset,
                                  epoch=epoch,
                                  loss=loss):
                    break

        # Recover the best epoch
        path = os.path.join("checkpoints", f"{dataset.name}.pth")
        config_path = os.path.join("checkpoints",
                                   f"{dataset.name}_config.json")
        _, _ = self._load_checkpoint(path, config_path)

    def predict(self, dataset, return_proba=False):
        """ Predicts the relation graph for the given dataset.
        """
        self.model.to(self.device)
        self.model.eval()

        predictions, instances = [], []
        with torch.no_grad():
            progress = tqdm(enumerate(dataset), total=len(dataset))
            for i, batch in progress:
                # uncompress the batch
                seq, mask, ent, label = batch
                seq = seq.to(self.device)
                mask = mask.to(self.device)
                ent = ent.to(self.device)
                label = label.to(self.device)

                output = self.model(seq, mask, ent)
                if not return_proba:
                    pred = np.argmax(output.detach().cpu().numpy(),
                                     axis=1).tolist()
                else:
                    pred = output.detach().cpu().numpy().tolist()
                inst = dataset.get_instances(i)

                predictions.extend(pred)
                instances.extend(inst)

        return predictions, instances

    def evaluate(self, dataset: Dataset, batch_size=1):
        """ Evaluates the model on the given dataset.
        """
        loss, precision, recall, fscore = self._val_step(
            dataset.get_val(batch_size), 0)
        return loss, precision, recall, fscore

    def save_model(self, path: str):
        """ Saves the model to a file.

        Usage:
        ``` 
        >>> rge = Framework(**config)
        >>> rge.fit(train_data)

        >>> rge.save_model("path/to/file")
        ```

        TODO
        """
        self.model.save_pretrained(path)
        with open(f"{path}/fine_tunning.config.json", 'wt') as f:
            json.dump(self.config, f, indent=4)

    @classmethod
    def load_model(cls,
                   path: str,
                   config_path: str = None,
                   from_checkpoint=False):
        """ Loads the model from a file.

        Args:
            path: str Path to the file that stores the model.

        Returns:
            Framework instance with the loaded model.

        Usage:
        ```
        >>> rge = Framework.load_model("path/to/model")
        ```

        TODO
        """
        if not from_checkpoint:
            config_path = path + '/fine_tunning.config.json'
            with open(config_path) as f:
                config = json.load(f)
            config['pretrained_model'] = path
            rge = cls(**config)

        else:
            if config_path is None:
                raise Exception(
                    'Loading the model from a checkpoint requires config_path argument.'
                )
            with open(config_path) as f:
                config = json.load(f)
            rge = cls(**config)
            rge._load_checkpoint(path, config_path)

        return rge
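The fit() method above sizes the linear schedule from the dataloader length, the gradient-accumulation factor and the epoch count. Below is a minimal, self-contained sketch of that wiring; the nn.Linear model and all of the numbers are placeholders, not part of the original Framework.

import torch
from torch import nn
from transformers import AdamW, get_linear_schedule_with_warmup

model = nn.Linear(10, 2)                     # stand-in for the real relation model
optimizer = AdamW(model.parameters(), lr=2e-5)

num_batches, grad_acc, epochs, warmup_steps = 1000, 4, 3, 100  # made-up sizes
num_training_steps = (num_batches // grad_acc) * epochs        # optimizer updates, not batches
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_training_steps)

for _ in range(5):                           # a few dummy optimizer updates
    loss = model(torch.randn(8, 10)).sum()
    loss.backward()
    optimizer.step()                         # one update per accumulated batch
    scheduler.step()                         # then advance the lr schedule
    optimizer.zero_grad()
print(scheduler.get_last_lr())               # lr after 5 of the 100 warmup steps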
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--round_name",
        default="",
        type=str,
        help="Which round to train on: one of r1, r2, r3, r4, r5.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=50,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")

    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    round_name_2_rounds = {
        'r1': ['n1', 'ood'],
        'r2': ['n1', 'n2', 'ood'],
        'r3': ['n1', 'n2', 'n3', 'ood'],
        'r4': ['n1', 'n2', 'n3', 'n4', 'ood'],
        'r5': ['n1', 'n2', 'n3', 'n4', 'n5', 'ood']
    }

    model = RobertaForSequenceClassification(3)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.load_state_dict(torch.load('../../data/MNLI_pretrained.pt'),
                          strict=False)
    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    banking77_class_list, ood_class_set, class_2_split = load_class_names()

    round_list = round_name_2_rounds.get(args.round_name)
    '''load the training examples, one list per round'''
    train_examples_list, train_class_list, train_class_2_split_list, class_2_sentlist_upto_this_round = processor.load_train(
        round_list[:-1])  # no ood training examples
    assert len(train_class_list) == len(train_class_2_split_list)
    # assert len(train_class_list) ==  20+(len(round_list)-2)*10
    '''dev and test'''
    dev_examples, dev_instance_size = processor.load_dev_or_test(
        round_list, train_class_list, class_2_sentlist_upto_this_round, 'dev')
    test_examples, test_instance_size = processor.load_dev_or_test(
        round_list, train_class_list, class_2_sentlist_upto_this_round, 'test')
    print('train size:', [len(train_i) for train_i in train_examples_list],
          ' dev size:', len(dev_examples), ' test size:', len(test_examples))
    entail_class_list = ['entailment', 'non-entailment']
    eval_class_list = train_class_list + ['ood']
    test_split_list = train_class_2_split_list + ['ood']
    train_dataloader_list = []
    for train_examples in train_examples_list:
        train_dataloader = examples_to_features(train_examples,
                                                entail_class_list,
                                                eval_class_list,
                                                args,
                                                tokenizer,
                                                args.train_batch_size,
                                                "classification",
                                                dataloader_mode='random')
        train_dataloader_list.append(train_dataloader)
    dev_dataloader = examples_to_features(dev_examples,
                                          entail_class_list,
                                          eval_class_list,
                                          args,
                                          tokenizer,
                                          args.eval_batch_size,
                                          "classification",
                                          dataloader_mode='sequential')
    test_dataloader = examples_to_features(test_examples,
                                           entail_class_list,
                                           eval_class_list,
                                           args,
                                           tokenizer,
                                           args.eval_batch_size,
                                           "classification",
                                           dataloader_mode='sequential')
    '''training'''
    max_test_acc = 0.0
    max_dev_acc = 0.0
    for round_index, round in enumerate(round_list[:-1]):
        '''for the new examples in each round, train multiple epochs'''
        train_dataloader = train_dataloader_list[round_index]
        for epoch_i in range(args.num_train_epochs):
            for _, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="train|" + round + '|epoch_' + str(epoch_i))):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                _, input_ids, input_mask, _, label_ids, _, _ = batch

                logits = model(input_ids, input_mask)
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, 3), label_ids.view(-1))
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
        print('\t\t round ', round, ' is over...')
    '''evaluation'''
    model.eval()
    '''test'''
    acc_each_round = []
    preds = []
    gold_guids = []
    gold_premise_ids = []
    gold_hypothesis_ids = []
    for _, batch in enumerate(tqdm(test_dataloader, desc="test")):
        guids, input_ids, input_mask, _, label_ids, premise_class_ids, hypothesis_class_id = batch
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)

        gold_guids += list(guids.detach().cpu().numpy())
        gold_premise_ids += list(premise_class_ids.detach().cpu().numpy())
        gold_hypothesis_ids += list(hypothesis_class_id.detach().cpu().numpy())

        with torch.no_grad():
            logits = model(input_ids, input_mask)
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(preds[0],
                                 logits.detach().cpu().numpy(),
                                 axis=0)
    preds = softmax(preds[0], axis=1)

    pred_label_3way = np.argmax(preds, axis=1)  # 0 means "entailment"
    pred_probs = list(
        preds[:, 0])  # probability of the "entailment" class, one per input
    assert len(pred_label_3way) == len(test_examples)
    assert len(pred_probs) == len(test_examples)
    assert len(gold_premise_ids) == len(test_examples)
    assert len(gold_hypothesis_ids) == len(test_examples)
    assert len(gold_guids) == len(test_examples)

    guid_2_premise_idlist = defaultdict(list)
    guid_2_hypoID_2_problist_labellist = {}
    for guid_i, threeway_i, prob_i, premise_i, hypo_i in zip(
            gold_guids, pred_label_3way, pred_probs, gold_premise_ids,
            gold_hypothesis_ids):
        guid_2_premise_idlist[guid_i].append(premise_i)
        hypoID_2_problist_labellist = guid_2_hypoID_2_problist_labellist.get(
            guid_i)
        if hypoID_2_problist_labellist is None:
            hypoID_2_problist_labellist = {}
        lists = hypoID_2_problist_labellist.get(hypo_i)
        if lists is None:
            lists = [[], []]
        lists[0].append(prob_i)
        lists[1].append(threeway_i)
        hypoID_2_problist_labellist[hypo_i] = lists
        guid_2_hypoID_2_problist_labellist[
            guid_i] = hypoID_2_problist_labellist

    pred_label_ids = []
    gold_label_ids = []
    for guid in range(test_instance_size):
        assert len(set(guid_2_premise_idlist.get(guid))) == 1
        gold_label_ids.append(guid_2_premise_idlist.get(guid)[0])
        '''infer predict label id'''
        hypoID_2_problist_labellist = guid_2_hypoID_2_problist_labellist.get(
            guid)

        final_max_mean_prob = 0.0
        final_hypo_id = -1
        for hypo_id, problist_labellist in hypoID_2_problist_labellist.items():
            problist = problist_labellist[0]
            mean_prob = np.mean(problist)
            labellist = problist_labellist[1]
            same_cluster_times = labellist.count(
                0)  # 'entailment' is the first label
            same_cluster = False
            if same_cluster_times / len(labellist) > 0.5:
                same_cluster = True

            if same_cluster and mean_prob > final_max_mean_prob:
                final_max_mean_prob = mean_prob
                final_hypo_id = hypo_id
        if final_hypo_id != -1:  # can find a class that it belongs to
            pred_label_ids.append(final_hypo_id)
        else:
            pred_label_ids.append(len(train_class_list))

    assert len(pred_label_ids) == len(gold_label_ids)
    acc_each_round = []
    for round_name_id in round_list:
        #base, n1, n2, ood
        round_size = 0
        round_hit = 0
        if round_name_id != 'ood':
            for ii, gold_label_id in enumerate(gold_label_ids):
                if test_split_list[gold_label_id] == round_name_id:
                    round_size += 1
                    if gold_label_id == pred_label_ids[ii]:
                        round_hit += 1
            acc_i = round_hit / round_size
            acc_each_round.append(acc_i)
        else:
            '''ood: report the F1 of the binary ood decision'''
            gold_binary_list = []
            pred_binary_list = []
            for ii, gold_label_id in enumerate(gold_label_ids):
                gold_binary_list.append(1 if test_split_list[gold_label_id] ==
                                        round_name_id else 0)
                pred_binary_list.append(1 if pred_label_ids[ii] ==
                                        len(train_class_list) else 0)
            overlap = 0
            for i in range(len(gold_binary_list)):
                if gold_binary_list[i] == 1 and pred_binary_list[i] == 1:
                    overlap += 1
            recall = overlap / (1e-6 + sum(gold_binary_list))
            precision = overlap / (1e-6 + sum(pred_binary_list))
            acc_i = 2 * recall * precision / (1e-6 + recall + precision)
            acc_each_round.append(acc_i)

    print('final_test_performance:', acc_each_round)
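The optimizer setup above excludes biases and LayerNorm parameters from weight decay, a grouping that recurs in several of these examples. The sketch below is a stand-alone illustration of that grouping; the deliberately tiny, randomly initialised BertConfig is only there so the snippet runs without downloading anything.

from transformers import AdamW, BertConfig, BertModel

config = BertConfig(hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config)                    # randomly initialised, nothing downloaded

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [{
    'params': [p for n, p in model.named_parameters()
               if not any(nd in n for nd in no_decay)],
    'weight_decay': 0.01
}, {
    'params': [p for n, p in model.named_parameters()
               if any(nd in n for nd in no_decay)],
    'weight_decay': 0.0
}]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
print(len(optimizer_grouped_parameters[0]['params']), 'decayed parameters,',
      len(optimizer_grouped_parameters[1]['params']), 'without decay')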
Beispiel #26
0
class CXRBERT_Trainer():
    def __init__(self, args, train_dataloader, test_dataloader=None):
        self.args = args

        cuda_condition = torch.cuda.is_available() and args.with_cuda

        self.device = torch.device("cuda" if cuda_condition else "cpu")
        print('Current cuda device ', torch.cuda.current_device())  # check

        if args.weight_load:
            config = AutoConfig.from_pretrained(args.pre_trained_model_path)
            model_state_dict = torch.load(
                os.path.join(args.pre_trained_model_path, 'pytorch_model.bin'))
            self.model = CXRBERT.from_pretrained(args.pre_trained_model_path,
                                                 state_dict=model_state_dict,
                                                 config=config,
                                                 args=args).to(self.device)
            print('training restart with mid epoch')
            print(config)
        else:
            if args.bert_model == "albert-base-v2":
                config = AlbertConfig.from_pretrained(args.bert_model)
            elif args.bert_model == "emilyalsentzer/Bio_ClinicalBERT":
                config = AutoConfig.from_pretrained(args.bert_model)
            elif args.bert_model == "bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12":
                config = AutoConfig.from_pretrained(args.bert_model)
            elif args.bert_model == "bert-small-scratch":
                config = BertConfig.from_pretrained(
                    "google/bert_uncased_L-4_H-512_A-8")
            elif args.bert_model == "bert-base-scratch":
                config = BertConfig.from_pretrained("bert-base-uncased")
            else:
                config = BertConfig.from_pretrained(
                    args.bert_model)  # bert-base, small, tiny

            self.model = CXRBERT(config, args).to(self.device)

        wandb.watch(self.model)

        if args.with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model,
                                         device_ids=args.cuda_devices)

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.optimizer = AdamW(self.model.parameters(), lr=args.lr)

        self.mlm_criterion = nn.CrossEntropyLoss(ignore_index=-100)
        self.itm_criterion = nn.CrossEntropyLoss()

        self.log_freq = args.log_freq
        self.step_cnt = 0

        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):

        self.model.train()

        train_losses = []
        train_itm_loss = []
        train_mlm_loss = []

        train_data_iter = tqdm.tqdm(enumerate(self.train_data),
                                    desc=f'EP_:{epoch}',
                                    total=len(self.train_data),
                                    bar_format='{l_bar}{r_bar}')
        total_correct = 0
        total_element = 0
        total_mlm_correct = 0
        total_mlm_element = 0

        total_valid_correct = 0
        total_valid_element = 0
        total_mlm_valid_correct = 0
        total_mlm_valid_element = 0

        for i, data in train_data_iter:

            cls_tok, input_ids, txt_labels, attn_masks, img, segment, is_aligned, sep_tok, itm_prob = data

            cls_tok = cls_tok.to(self.device)
            input_ids = input_ids.to(self.device)
            txt_labels = txt_labels.to(self.device)
            attn_masks = attn_masks.to(self.device)
            img = img.to(self.device)
            segment = segment.to(self.device)
            is_aligned = is_aligned.to(self.device)
            sep_tok = sep_tok.to(self.device)

            mlm_output, itm_output = self.model(cls_tok, input_ids, attn_masks,
                                                segment, img, sep_tok)

            if self.args.mlm_task and not self.args.itm_task:
                mlm_loss = self.mlm_criterion(mlm_output.transpose(1, 2),
                                              txt_labels)
                loss = mlm_loss
                print('only mlm_loss')

            if self.args.itm_task and not self.args.mlm_task:
                itm_loss = self.itm_criterion(itm_output, is_aligned)
                loss = itm_loss
                print('only itm_loss')

            if self.args.mlm_task and self.args.itm_task:

                mlm_loss = self.mlm_criterion(mlm_output.transpose(1, 2),
                                              txt_labels)
                train_mlm_loss.append(mlm_loss.item())

                itm_loss = self.itm_criterion(itm_output, is_aligned)
                train_itm_loss.append(itm_loss.item())

                loss = itm_loss + mlm_loss

            train_losses.append(loss.item())
            self.optimizer.zero_grad()  # above
            loss.backward()
            self.optimizer.step()

            if self.args.itm_task:
                correct = itm_output.argmax(dim=-1).eq(is_aligned).sum().item()
                total_correct += correct
                total_element += is_aligned.nelement()

            if self.args.mlm_task:
                eq = (mlm_output.argmax(dim=-1).eq(txt_labels)).cpu().numpy()
                txt_labels_np = txt_labels.cpu().numpy()
                for bs, label in enumerate(txt_labels_np):
                    index = np.where(label == -100)[0]
                    f_label = np.delete(label, index)
                    f_eq = np.delete(eq[bs], index)
                    total_mlm_correct += f_eq.sum()
                    total_mlm_element += len(f_label)

        print("avg loss per epoch", np.mean(train_losses))
        print("avg itm acc per epoch",
              round(total_correct / total_element * 100, 3))
        if self.args.mlm_task and self.args.itm_task:
            wandb.log(
                {
                    "avg_loss": np.mean(train_losses),
                    "avg_mlm_loss": np.mean(train_mlm_loss),
                    "avg_itm_loss": np.mean(train_itm_loss),
                    "itm_acc": total_correct / total_element * 100,
                    "mlm_acc": total_mlm_correct / total_mlm_element * 100
                },
                step=epoch)

        if self.args.itm_task and not self.args.mlm_task:
            wandb.log(
                {
                    "avg_loss": np.mean(train_losses),
                    "itm_epoch_acc": total_correct / total_element * 100
                },
                step=epoch)

        if self.args.mlm_task and not self.args.itm_task:
            wandb.log(
                {
                    "avg_loss": np.mean(train_losses),
                    "mlm_epoch_acc":
                    total_mlm_correct / total_mlm_element * 100
                },
                step=epoch)

        test_data_iter = tqdm.tqdm(enumerate(self.test_data),
                                   desc=f'EP_:{epoch}',
                                   total=len(self.test_data),
                                   bar_format='{l_bar}{r_bar}')
        self.model.eval()
        with torch.no_grad():
            eval_losses = []
            eval_mlm_loss = []
            eval_itm_loss = []
            for i, data in test_data_iter:
                cls_tok, input_ids, txt_labels, attn_masks, img, segment, is_aligned, sep_tok, itm_prob = data

                cls_tok = cls_tok.to(self.device)
                input_ids = input_ids.to(self.device)
                txt_labels = txt_labels.to(self.device)
                attn_masks = attn_masks.to(self.device)
                img = img.to(self.device)
                segment = segment.to(self.device)
                is_aligned = is_aligned.to(self.device)
                sep_tok = sep_tok.to(self.device)

                mlm_output, itm_output = self.model(cls_tok, input_ids,
                                                    attn_masks, segment, img,
                                                    sep_tok)

                if self.args.mlm_task and not self.args.itm_task:
                    valid_mlm_loss = self.mlm_criterion(
                        mlm_output.transpose(1, 2), txt_labels)
                    valid_loss = valid_mlm_loss
                    print('only valid mlm loss')

                if self.args.itm_task and not self.args.mlm_task:
                    valid_itm_loss = self.itm_criterion(itm_output, is_aligned)
                    valid_loss = valid_itm_loss
                    print('only valid itm loss')

                if self.args.mlm_task and self.args.itm_task:
                    # TODO: weight each loss, mlm > itm
                    valid_mlm_loss = self.mlm_criterion(
                        mlm_output.transpose(1, 2), txt_labels)
                    valid_itm_loss = self.itm_criterion(itm_output, is_aligned)
                    eval_mlm_loss.append(valid_mlm_loss.item())
                    eval_itm_loss.append(valid_itm_loss.item())

                    valid_loss = valid_itm_loss + valid_mlm_loss

                eval_losses.append(valid_loss.item())

                if self.args.itm_task:
                    valid_correct = itm_output.argmax(
                        dim=-1).eq(is_aligned).sum().item()
                    total_valid_correct += valid_correct
                    total_valid_element += is_aligned.nelement()

                if self.args.mlm_task:
                    eq = (mlm_output.argmax(
                        dim=-1).eq(txt_labels)).cpu().numpy()
                    txt_labels_np = txt_labels.cpu().numpy()
                    for bs, label in enumerate(txt_labels_np):
                        index = np.where(label == -100)[0]
                        f_label = np.delete(label, index)
                        f_eq = np.delete(eq[bs], index)
                        total_mlm_valid_correct += f_eq.sum()
                        total_mlm_valid_element += len(f_label)

            print("avg loss in testset", np.mean(eval_losses))
            print("avg itm acc in testset",
                  round(total_valid_correct / total_valid_element * 100, 3))

            if self.args.mlm_task and self.args.itm_task:
                wandb.log(
                    {
                        "eval_avg_loss":
                        np.mean(eval_losses),
                        "eval_mlm_loss":
                        np.mean(eval_mlm_loss),
                        "eval_itm_loss":
                        np.mean(eval_itm_loss),
                        "eval_itm_acc":
                        total_valid_correct / total_valid_element * 100,
                        "eval_mlm_acc":
                        total_mlm_valid_correct / total_mlm_valid_element * 100
                    },
                    step=epoch)

            if self.args.itm_task and not self.args.mlm_task:
                wandb.log(
                    {
                        "eval_avg_loss":
                        np.mean(eval_losses),
                        "eval_itm_epoch_acc":
                        total_valid_correct / total_valid_element * 100
                    },
                    step=epoch)

            if self.args.mlm_task and not self.args.itm_task:
                wandb.log(
                    {
                        "eval_avg_loss":
                        np.mean(eval_losses),
                        "eval_mlm_epoch_acc":
                        total_mlm_valid_correct / total_mlm_valid_element * 100
                    },
                    step=epoch)

    def save(self, epoch, file_path):
        save_path_per_ep = os.path.join(file_path, str(epoch))
        if not os.path.exists(save_path_per_ep):
            os.mkdir(save_path_per_ep)
            os.chmod(save_path_per_ep, 0o777)

        if torch.cuda.device_count() > 1:
            self.model.module.save_pretrained(save_path_per_ep)
            print(f'Multi_EP: {epoch} Model saved on {save_path_per_ep}')
        else:
            self.model.save_pretrained(save_path_per_ep)
            print(f'Single_EP: {epoch} Model saved on {save_path_per_ep}')
        os.chmod(save_path_per_ep + '/pytorch_model.bin', 0o777)
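CXRBERT_Trainer.train above measures masked-LM accuracy by deleting the positions labelled -100 before counting hits. The sketch below shows an equivalent, vectorised version of that bookkeeping on toy tensors; the shapes and vocabulary size are invented for illustration.

import torch

mlm_logits = torch.randn(2, 5, 30)           # (batch, seq_len, vocab), toy sizes
txt_labels = torch.randint(0, 30, (2, 5))
txt_labels[:, ::2] = -100                    # unmasked positions are ignored

preds = mlm_logits.argmax(dim=-1)
masked = txt_labels != -100                  # keep only the masked positions
mlm_correct = (preds[masked] == txt_labels[masked]).sum().item()
mlm_total = masked.sum().item()
print('mlm acc:', mlm_correct / mlm_total)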
def train():
    device = Config.device
    # Prepare the data
    train_data, dev_data = build_dataset(Config)
    train_iter = DatasetIterater(train_data, Config)
    dev_iter = DatasetIterater(dev_data, Config)

    model = Model().to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    # Here we use the BertAdam-style AdamW optimizer instead.
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=Config.learning_rate,
        correct_bias=False)  # set correct_bias=False to reproduce BertAdam's behaviour
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.05 * len(train_iter) * Config.num_epochs),
        num_training_steps=len(train_iter) * Config.num_epochs)  # warm up over the first 5% of steps, then decay linearly

    model.to(device)
    model.train()

    best_loss = 100000.0
    for epoch in range(Config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, Config.num_epochs))
        for step, batch in enumerate(train_iter):
            start_time = time.time()
            ids, input_ids, input_mask, start_positions, end_positions = \
                batch[0], batch[1], batch[2], batch[3], batch[4]
            input_ids, input_mask, start_positions, end_positions = \
                input_ids.to(device), input_mask.to(device), start_positions.to(device), end_positions.to(device)

            # print(input_ids.size())
            # print(input_mask.size())
            # print(start_positions.size())
            # print(end_positions.size())

            loss, _, _ = model(input_ids,
                               attention_mask=input_mask,
                               start_positions=start_positions,
                               end_positions=end_positions)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            time_str = datetime.datetime.now().isoformat()
            log_str = 'time:{}, epoch:{}, step:{}, loss:{:8f}, spend_time:{:6f}'.format(
                time_str, epoch, step, loss,
                time.time() - start_time)
            rainbow(log_str)

            train_loss.append(loss.item())

        if epoch % 1 == 0:
            eval_loss = valid(model, dev_iter)
            if eval_loss < best_loss:
                best_loss = eval_loss
                torch.save(model.state_dict(),
                           './save_model/' + 'best_model.bin')
                model.train()
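The training setup above passes correct_bias=False so AdamW matches the original BertAdam update, and warms the learning rate up linearly. Below is a brief stand-alone sketch of converting a warmup proportion into the integer step count expected by get_linear_schedule_with_warmup; the nn.Linear model and the step counts are placeholders.

from torch import nn
from transformers import AdamW, get_linear_schedule_with_warmup

model = nn.Linear(8, 2)                      # placeholder model
steps_per_epoch, num_epochs, warmup_proportion = 200, 3, 0.05  # made-up values

total_steps = steps_per_epoch * num_epochs
warmup_steps = int(total_steps * warmup_proportion)

optimizer = AdamW(model.parameters(), lr=5e-5, correct_bias=False)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
print('warming up for', warmup_steps, 'of', total_steps, 'steps')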
Beispiel #28
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, n_gpu)
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()

    num_labels = len(["entailment", "neutral", "contradiction"])
    # pretrain_model_dir = 'roberta-large' #'roberta-large' , 'roberta-large-mnli'
    pretrain_model_dir = '/export/home/Dataset/BERT_pretrained_mine/TrainedModelReminder/RoBERTa_on_MNLI_SNLI_SciTail_RTE_ANLI_SpecialToken_epoch_2_acc_4.156359461121103'  #'roberta-large' , 'roberta-large-mnli'
    model = RobertaForSequenceClassification.from_pretrained(
        pretrain_model_dir, num_labels=num_labels)
    tokenizer = RobertaTokenizer.from_pretrained(
        pretrain_model_dir, do_lower_case=args.do_lower_case)
    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    #MNLI-SNLI-SciTail-RTE-SICK
    train_examples_MNLI, dev_examples_MNLI = processor.get_MNLI_train_and_dev(
        '/export/home/Dataset/glue_data/MNLI/train.tsv',
        '/export/home/Dataset/glue_data/MNLI/dev_mismatched.tsv'
    )  #train_pu_half_v1.txt
    train_examples_SNLI, dev_examples_SNLI = processor.get_SNLI_train_and_dev(
        '/export/home/Dataset/glue_data/SNLI/train.tsv',
        '/export/home/Dataset/glue_data/SNLI/dev.tsv')
    train_examples_SciTail, dev_examples_SciTail = processor.get_SciTail_train_and_dev(
        '/export/home/Dataset/SciTailV1/tsv_format/scitail_1.0_train.tsv',
        '/export/home/Dataset/SciTailV1/tsv_format/scitail_1.0_dev.tsv')
    train_examples_RTE, dev_examples_RTE = processor.get_RTE_train_and_dev(
        '/export/home/Dataset/glue_data/RTE/train.tsv',
        '/export/home/Dataset/glue_data/RTE/dev.tsv')
    train_examples_ANLI, dev_examples_ANLI = processor.get_ANLI_train_and_dev(
        'train', 'dev',
        '/export/home/Dataset/para_entail_datasets/ANLI/anli_v0.1/')

    train_examples = train_examples_MNLI + train_examples_SNLI + train_examples_SciTail + train_examples_RTE + train_examples_ANLI
    dev_examples_list = [
        dev_examples_MNLI, dev_examples_SNLI, dev_examples_SciTail,
        dev_examples_RTE, dev_examples_ANLI
    ]

    dev_task_label = [0, 0, 1, 1, 0]
    task_names = ['MNLI', 'SNLI', 'SciTail', 'RTE', 'ANLI']
    '''filter challenging neighbors'''
    neighbor_id_list = []
    readfile = codecs.open('neighbors_indices_before_dropout_eud.v3.txt', 'r',
                           'utf-8')
    for line in readfile:
        neighbor_id_list.append(int(line.strip()))
    readfile.close()
    print('neighbor_id_list size:', len(neighbor_id_list))
    truncated_train_examples = [train_examples[i] for i in neighbor_id_list]
    train_examples = truncated_train_examples

    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0

    train_features = convert_examples_to_features(
        train_examples,
        label_list,
        args.max_seq_length,
        tokenizer,
        output_mode,
        cls_token_at_end=
        False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=
        True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
        pad_on_left=
        False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
        pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
        pad_token_segment_id=0)  #4 if args.model_type in ['xlnet'] else 0,)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)
    all_task_label_ids = torch.tensor([f.task_label for f in train_features],
                                      dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_label_ids, all_task_label_ids)
    train_sampler = RandomSampler(train_data)

    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  drop_last=True)
    '''dev data to features'''
    valid_dataloader_list = []
    for valid_examples_i in dev_examples_list:
        valid_features = convert_examples_to_features(
            valid_examples_i,
            label_list,
            args.max_seq_length,
            tokenizer,
            output_mode,
            cls_token_at_end=
            False,  #bool(args.model_type in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=0,  #2 if args.model_type in ['xlnet'] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=
            True,  #bool(args.model_type in ['roberta']),           # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=
            False,  #bool(args.model_type in ['xlnet']),                 # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token
                                                       ])[0],
            pad_token_segment_id=0
        )  #4 if args.model_type in ['xlnet'] else 0,)

        logger.info("***** valid_examples *****")
        logger.info("  Num examples = %d", len(valid_examples_i))
        valid_input_ids = torch.tensor([f.input_ids for f in valid_features],
                                       dtype=torch.long)
        valid_input_mask = torch.tensor([f.input_mask for f in valid_features],
                                        dtype=torch.long)
        valid_segment_ids = torch.tensor(
            [f.segment_ids for f in valid_features], dtype=torch.long)
        valid_label_ids = torch.tensor([f.label_id for f in valid_features],
                                       dtype=torch.long)
        valid_task_label_ids = torch.tensor(
            [f.task_label for f in valid_features], dtype=torch.long)

        valid_data = TensorDataset(valid_input_ids, valid_input_mask,
                                   valid_segment_ids, valid_label_ids,
                                   valid_task_label_ids)
        valid_sampler = SequentialSampler(valid_data)
        valid_dataloader = DataLoader(valid_data,
                                      sampler=valid_sampler,
                                      batch_size=args.eval_batch_size)
        valid_dataloader_list.append(valid_dataloader)

    iter_co = 0
    for epoch_i in trange(int(args.num_train_epochs), desc="Epoch"):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, task_label_ids = batch
            logits = model(input_ids, input_mask, None, labels=None)

            prob_matrix = F.log_softmax(logits[0].view(-1, num_labels), dim=1)
            '''multiplying by 1.0 creates a copy, so the in-place assignment
            below does not modify prob_matrix and break autograd'''
            new_prob_matrix = prob_matrix * 1.0
            '''change the entail prob to p or 1-p'''
            changed_places = torch.nonzero(task_label_ids, as_tuple=False)
            new_prob_matrix[changed_places,
                            0] = 1.0 - prob_matrix[changed_places, 0]

            loss = F.nll_loss(new_prob_matrix, label_ids.view(-1))

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            iter_co += 1

            # if iter_co % len(train_dataloader) ==0:
            if iter_co % (len(train_dataloader) // 5) == 0:
                '''
                evaluate on the dev sets five times per epoch
                '''
                # if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
                #     model = torch.nn.DataParallel(model)
                model.eval()
                for m in model.modules():
                    if isinstance(m, torch.nn.BatchNorm2d):
                        m.track_running_stats = False
                # logger.info("***** Running evaluation *****")
                # logger.info("  Num examples = %d", len(valid_examples_MNLI))
                # logger.info("  Batch size = %d", args.eval_batch_size)

                dev_acc_sum = 0.0
                for idd, valid_dataloader in enumerate(valid_dataloader_list):
                    task_label = dev_task_label[idd]
                    eval_loss = 0
                    nb_eval_steps = 0
                    preds = []
                    gold_label_ids = []
                    # print('Evaluating...', task_label)
                    # for _, batch in enumerate(tqdm(valid_dataloader, desc=task_names[idd])):
                    for _, batch in enumerate(valid_dataloader):
                        batch = tuple(t.to(device) for t in batch)
                        input_ids, input_mask, segment_ids, label_ids, task_label_ids = batch
                        if task_label == 0:
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())
                        else:
                            '''SciTail, RTE'''
                            task_label_ids_list = list(
                                task_label_ids.detach().cpu().numpy())
                            gold_label_batch_fake = list(
                                label_ids.detach().cpu().numpy())
                            for ex_id, label_id in enumerate(
                                    gold_label_batch_fake):
                                if task_label_ids_list[ex_id] == 0:
                                    gold_label_ids.append(label_id)  #0
                                else:
                                    gold_label_ids.append(1)  #1
                        with torch.no_grad():
                            logits = model(input_ids=input_ids,
                                           attention_mask=input_mask,
                                           token_type_ids=None,
                                           labels=None)
                        logits = logits[0]
                        if len(preds) == 0:
                            preds.append(logits.detach().cpu().numpy())
                        else:
                            preds[0] = np.append(preds[0],
                                                 logits.detach().cpu().numpy(),
                                                 axis=0)

                    preds = preds[0]
                    pred_probs = softmax(preds, axis=1)
                    pred_label_ids_3way = np.argmax(pred_probs, axis=1)
                    if task_label == 0:
                        '''3-way tasks MNLI, SNLI, ANLI'''
                        pred_label_ids = pred_label_ids_3way
                    else:
                        '''SciTail, RTE'''
                        pred_label_ids = []
                        for pred_label_i in pred_label_ids_3way:
                            if pred_label_i == 0:
                                pred_label_ids.append(0)
                            else:
                                pred_label_ids.append(1)
                    assert len(pred_label_ids) == len(gold_label_ids)
                    hit_co = 0
                    for k in range(len(pred_label_ids)):
                        if pred_label_ids[k] == gold_label_ids[k]:
                            hit_co += 1
                    test_acc = hit_co / len(gold_label_ids)
                    dev_acc_sum += test_acc
                    print(task_names[idd], ' dev acc:', test_acc)
                '''store the model, because we can test after a max_dev acc reached'''
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                store_transformers_models(
                    model_to_save, tokenizer,
                    '/export/home/Dataset/BERT_pretrained_mine/TrainedModelReminder/',
                    'RoBERTa_on_MNLI_SNLI_SciTail_RTE_ANLI_SpecialToken_Filter_1_epoch_'
                    + str(epoch_i) + '_acc_' + str(dev_acc_sum))
            if Config.gradient_accumulation_steps > 1:
                loss = loss / Config.gradient_accumulation_steps
            loss.backward()

            nb_tr_steps += 1
            tr_mask_acc.update(mask_metric.value(), n=input_ids.size(0))
            tr_sop_acc.update(sop_metric.value(), n=input_ids.size(0))
            tr_loss.update(loss.item(), n=1)
            tr_mask_loss.update(masked_lm_loss.item(), n=1)
            tr_sop_loss.update(next_sentence_loss.item(), n=1)

            if (step + 1) % Config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               Config.max_grad_norm)
                scheduler.step()
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

            if global_step % Config.num_save_steps == 0:
                model_to_save = model.module if hasattr(model,
                                                        'module') else model
                output_model_file = os.path.join(
                    Config.output_dir,
                    'pytorch_model_epoch{}.bin'.format(global_step))
                torch.save(model_to_save.state_dict(), output_model_file)

                # save config
                output_config_file = Config.output_dir + "config.json"
                with open(str(output_config_file), 'w') as f:
                    f.write(model_to_save.config.to_json_string())
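Several of the loops above divide the loss by gradient_accumulation_steps and only clip, step and zero the gradients every N batches. The sketch below is a compact, runnable version of that pattern; the toy model, data and accum_steps value are placeholders.

import torch
from torch import nn
from transformers import AdamW

model = nn.Linear(16, 3)                     # toy model and data
optimizer = AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
accum_steps, max_grad_norm = 4, 1.0

for step in range(16):
    x = torch.randn(8, 16)
    y = torch.randint(0, 3, (8,))
    loss = loss_fn(model(x), y) / accum_steps   # scale so the summed gradients average out
    loss.backward()
    if (step + 1) % accum_steps == 0:           # update only every accum_steps batches
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()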
Beispiel #30
0
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps,
                                      shuffle=True,
                                      num_workers=2)

        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        if dev_dataset:
            valid_loss = []
            valid_ind_class_acc = []
        iteration = 0
        for i in range(args.n_epoch):

            model.train()

            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                logits = model(token, mask, type_ids)
                loss = classified_loss(logits, y.long())
                total_loss += loss.item()
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                # bp and update parameters
                if (global_step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                i, total_loss / n_sample))
            logger.info('-' * 30)

            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # 1  means: save the model
                # 0  means: no need to save the model
                # -1 means: no improvement and patience exceeded, so early stop
                signal = early_stopping(eval_result['accuracy'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

                # logger.info(eval_result)

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        freeze_data['valid_loss'] = valid_loss