class MTBModel(RelationExtractor):
    def __init__(self, config: dict):
        """
        Matching the Blanks Model.

        Wires up the pre-training data loader, the transformer, the loss, the
        optimizer with its linear warm-up schedule and the per-epoch
        bookkeeping, and creates the checkpoint directory on disk.

        Args:
            config: configuration parameters. This constructor itself reads
                the keys "experiment_name", "transformer", "use_gpu", "lr",
                "epochs", "max_size" and "batch_size"; the whole dict is also
                handed to MTBPretrainDataLoader.

        Raises:
            ValueError: if both "[E1]" and "[E2]" resolve to token id 1.
        """
        super().__init__()
        self.config = config
        self.experiment_name = config.get("experiment_name")
        self.transformer = config.get("transformer")

        # --- data ---------------------------------------------------------
        self.data_loader = MTBPretrainDataLoader(self.config)
        self.train_len = len(self.data_loader.train_generator)
        logger.info("Loaded %d pre-training samples." % self.train_len)

        # --- model and tokenizer ------------------------------------------
        self.model = BertModel.from_pretrained(
            pretrained_model_name_or_path=self.transformer,
            model_size=self.transformer,
            force_download=False,
        )
        self.tokenizer = self.data_loader.tokenizer
        # The embedding matrix is resized to the tokenizer's vocabulary size.
        self.model.resize_token_embeddings(len(self.tokenizer))
        e1_id = self.tokenizer.convert_tokens_to_ids("[E1]")
        e2_id = self.tokenizer.convert_tokens_to_ids("[E2]")
        # NOTE(review): id 1 is presumably the fallback id for tokens missing
        # from the vocabulary - confirm against the tokenizer in use.
        if e1_id == 1 and e2_id == 1:
            raise ValueError("e1_id == e2_id == 1")

        # --- device ---------------------------------------------------------
        self.train_on_gpu = torch.cuda.is_available() and config.get(
            "use_gpu", True)
        if self.train_on_gpu:
            logger.info("Train on GPU")
            self.model.cuda()

        # --- loss, optimizer, scheduler -------------------------------------
        self.criterion = MTBLoss(lm_ignore_idx=self.tokenizer.pad_token_id)

        # Parameters whose name contains one of these markers get no weight
        # decay; every other parameter is decayed with 0.01.
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        decayed, undecayed = [], []
        for name, param in self.model.named_parameters():
            if any(marker in name for marker in no_decay):
                undecayed.append(param)
            else:
                decayed.append(param)
        optimizer_grouped_parameters = [
            {"params": decayed, "weight_decay": 0.01},
            {"params": undecayed, "weight_decay": 0.0},
        ]
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.config.get("lr"))

        # Estimated number of scheduler steps for the whole run; the first
        # tenth of them is used as warm-up.
        ovr_steps = (self.config.get("epochs") * self.train_len *
                     self.config.get("max_size") * 2 /
                     self.config.get("batch_size"))
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, ovr_steps // 10, ovr_steps)

        # --- per-epoch bookkeeping -------------------------------------------
        self._start_epoch = 0
        self._best_mtb_bce = 50
        self._train_loss = []
        self._train_lm_acc = []
        self._lm_acc = []
        self._mtb_bce = []

        self.checkpoint_dir = os.path.join("models", "MTB-pretraining",
                                           self.experiment_name,
                                           self.transformer)
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

        # Counters used for gradient accumulation in _train_on_batch.
        self._batch_points_seen = 0
        self._points_seen = 0

    def load_best_model(self, checkpoint_dir: str):
        """
        Loads the current best model in the checkpoint directory.

        The actual loading is delegated to the parent class; this method only
        unpacks the training state stored in the returned checkpoint.

        Args:
            checkpoint_dir: Checkpoint directory path

        Returns:
            Tuple with the checkpoint entries "epoch", "best_mtb_bce",
            "losses_per_epoch", "accuracy_per_epoch", "lm_acc" and
            "blanks_mse", in exactly that order.
        """
        state = super().load_best_model(checkpoint_dir)
        wanted_keys = (
            "epoch",
            "best_mtb_bce",
            "losses_per_epoch",
            "accuracy_per_epoch",
            "lm_acc",
            "blanks_mse",
        )
        return tuple(state[key] for key in wanted_keys)

    def train(self, **kwargs):
        """
        Runs the training.

        Optionally resumes from the best checkpoint (config key "resume"),
        then trains from the start epoch up to config["epochs"], writing the
        KPI table and plots after every epoch.

        Args:
            kwargs: Additional keyword arguments. Only
                "save_best_model_only" (default False) is read.

        Returns:
            The trained model (self.model).
        """
        save_best_model_only = kwargs.get("save_best_model_only", False)
        results_path = os.path.join(
            "results",
            "MTB-pretraining",
            self.experiment_name,
            self.transformer,
        )
        best_model_path = os.path.join(self.checkpoint_dir,
                                       "best_model.pth.tar")
        resume = self.config.get("resume", False)
        if resume and os.path.exists(best_model_path):
            (
                self._start_epoch,
                self._best_mtb_bce,
                self._train_loss,
                self._train_lm_acc,
                self._lm_acc,
                self._mtb_bce,
            ) = self.load_best_model(self.checkpoint_dir)

        logger.info("Starting training process")
        # Progress is logged about 100 times per epoch. Clamp to 1 so that a
        # generator with fewer than 100 batches does not yield an interval of
        # 0, which would end in a modulo-by-zero inside _train_epoch.
        update_size = max(1, len(self.data_loader.train_generator) // 100)
        for epoch in range(self._start_epoch, self.config.get("epochs")):
            self._train_epoch(epoch, update_size, save_best_model_only)
            data = self._write_kpis(results_path)
            self._plot_results(data, results_path)
        logger.info("Finished Training.")
        return self.model

    def _plot_results(self, data, save_at):
        """
        Draw the KPI curves and store them as PNG files.

        Three figures are written into `save_at`: the training loss, the
        validation MTB binary cross entropy, and the train/validation LM
        accuracy in one chart. File names carry the transformer name.

        Args:
            data: DataFrame with the columns "Epoch", "Train Loss",
                "Val MTB Loss", "Train LM Accuracy" and "Val LM Accuracy"
            save_at: Directory the images are written to
        """

        def _store(figure, name_template):
            # Saves the current figure, then releases it.
            target = os.path.join(save_at,
                                  name_template.format(self.transformer))
            plt.savefig(target)
            plt.close(figure)

        # One curve per figure: (column, title, file name template).
        single_curves = (
            ("Train Loss", "Training Loss", "train_loss_{0}.png"),
            ("Val MTB Loss", "Val MTB Binary Cross Entropy",
             "val_mtb_bce_{0}.png"),
        )
        for column, title, name_template in single_curves:
            fig, ax = plt.subplots(figsize=(20, 20))
            sns.lineplot(x="Epoch", y=column, ax=ax, data=data, linewidth=4)
            ax.set_title(title)
            _store(fig, name_template)

        # Long format so both accuracy series share one chart, split by hue.
        accuracy_columns = ["Epoch", "Train LM Accuracy", "Val LM Accuracy"]
        long_form = data[accuracy_columns].melt(id_vars="Epoch",
                                                var_name="Set",
                                                value_name="LM Accuracy")
        fig, ax = plt.subplots(figsize=(20, 20))
        sns.lineplot(x="Epoch",
                     y="LM Accuracy",
                     hue="Set",
                     ax=ax,
                     data=long_form,
                     linewidth=4)
        ax.set_title("LM Accuracy")
        _store(fig, "lm_acc_{0}.png")

    def _write_kpis(self, results_path):
        """
        Collect the per-epoch KPIs in a table and dump it as CSV.

        Args:
            results_path: Directory for the CSV file; created if missing

        Returns:
            DataFrame with one row per epoch and the columns "Epoch",
            "Train Loss", "Train LM Accuracy", "Val LM Accuracy" and
            "Val MTB Loss".
        """
        Path(results_path).mkdir(parents=True, exist_ok=True)

        columns = {
            "Epoch": np.arange(len(self._train_loss)),
            "Train Loss": self._train_loss,
            "Train LM Accuracy": self._train_lm_acc,
            "Val LM Accuracy": self._lm_acc,
            "Val MTB Loss": self._mtb_bce,
        }
        kpis = pd.DataFrame(columns)

        csv_file = os.path.join(results_path,
                                "kpis_{0}.csv".format(self.transformer))
        kpis.to_csv(csv_file, index=False)
        return kpis

    def _train_epoch(self,
                     epoch,
                     update_size,
                     save_best_model_only: bool = False):
        """
        Train on every batch of the training generator once.

        Batches whose sequence length exceeds 70 tokens are skipped, as are
        batches for which _train_on_batch reports no result (no masked
        labels). After the loop the epoch means are appended to the KPI
        lists and on_epoch_end is invoked.

        Args:
            epoch: Current epoch
            update_size: Number of batches between two progress log lines
            save_best_model_only: Whether to only save the best model so far
                or all of them
        """
        start_time = super()._train_epoch(epoch)

        # A non-positive interval would make the modulo below divide by zero.
        log_every = max(1, update_size)

        train_lm_acc, train_loss, train_mtb_bce = [], [], []

        for i, data in enumerate(tqdm(self.data_loader.train_generator)):
            sequence, masked_label, e1_e2_start, blank_labels = data
            if sequence.shape[1] > 70:
                continue
            batch_loss, batch_lm_acc, batch_mtb_bce = self._train_on_batch(
                sequence, masked_label, e1_e2_start, blank_labels)
            # Compare against None explicitly: a loss of exactly 0.0 is a
            # valid result and must not be dropped by a truthiness test.
            if batch_loss is not None:
                train_loss.append(batch_loss)
                train_lm_acc.append(batch_lm_acc)
                train_mtb_bce.append(batch_mtb_bce)
            if (i % log_every) == (log_every - 1):
                logger.info(
                    f"{i+1}/{self.train_len} pools: - " +
                    f"Train loss: {np.mean(train_loss)}, " +
                    f"Train LM accuracy: {np.mean(train_lm_acc)}, " +
                    f"Train MTB Binary Cross Entropy {np.mean(train_mtb_bce)}")

        # Keep the KPI lists aligned with the epoch count even if no batch
        # contributed; NaN marks "no data" without numpy's empty-mean warning.
        self._train_loss.append(
            np.mean(train_loss) if train_loss else float("nan"))
        self._train_lm_acc.append(
            np.mean(train_lm_acc) if train_lm_acc else float("nan"))

        self.on_epoch_end(epoch, self._mtb_bce, self._best_mtb_bce,
                          save_best_model_only)

        logger.info(
            f"Epoch finished, took {time.time() - start_time} seconds!")
        logger.info(f"Train Loss: {self._train_loss[-1]}!")
        logger.info(f"Train LM Accuracy: {self._train_lm_acc[-1]}!")
        logger.info(f"Validation LM Accuracy: {self._lm_acc[-1]}!")
        logger.info(
            f"Validation MTB Binary Cross Entropy: {self._mtb_bce[-1]}!")

    def on_epoch_end(self,
                     epoch,
                     benchmark,
                     baseline,
                     save_best_model_only: bool = False):
        """
        Function to run at the end of an epoch.

        Delegates to the parent class' on_epoch_end, then records the
        returned validation KPIs, lowers the best MTB binary cross entropy
        if the new value beats it, and finally lets the parent class save
        the model.

        Args:
            epoch: Current epoch
            benchmark: List of benchmark results
            baseline: Current baseline. Best model performance so far
            save_best_model_only: Whether to only save the best model so far
                or all of them
        """
        eval_result = super().on_epoch_end(epoch, benchmark, baseline)
        val_lm_acc = eval_result[0]
        val_mtb_bce = eval_result[1]

        # Lower is better for the binary cross entropy.
        if val_mtb_bce < self._best_mtb_bce:
            self._best_mtb_bce = val_mtb_bce

        self._mtb_bce.append(val_mtb_bce)
        self._lm_acc.append(val_lm_acc)
        super().save_on_epoch_end(self._mtb_bce, self._best_mtb_bce, epoch,
                                  save_best_model_only)

    def _train_on_batch(
        self,
        sequence,
        mskd_label,
        e1_e2_start,
        blank_labels,
    ):
        """
        Run forward and backward pass for one batch.

        Gradients are accumulated; the optimizer and the scheduler only step
        once more than config["batch_size"] points have been seen since the
        last step.

        Args:
            sequence: Token ids of the batch
            mskd_label: Labels of the masked tokens, padded with the pad id
            e1_e2_start: Entity marker positions, forwarded to the model
            blank_labels: Blank labels, forwarded to loss and metrics

        Returns:
            Tuple (loss per point, LM accuracy, MTB binary cross entropy),
            or (None, None, None) if the batch holds no masked label.
        """
        pad_id = self.tokenizer.pad_token_id
        mskd_label = mskd_label[mskd_label != pad_id]
        if mskd_label.shape[0] == 0:
            return None, None, None
        if self.train_on_gpu:
            mskd_label = mskd_label.cuda()

        blanks_logits, lm_logits = self._get_logits(e1_e2_start, sequence)
        loss = self.criterion(lm_logits, blanks_logits, mskd_label,
                              blank_labels)
        # Keep the unscaled value for reporting before the loss is scaled
        # down for gradient accumulation.
        raw_loss = loss.item()
        batch_size = self.config.get("batch_size")
        (loss / batch_size).backward()

        n_points = len(sequence)
        self._batch_points_seen += n_points
        self._points_seen += n_points
        if self._batch_points_seen > batch_size:
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()
            self._batch_points_seen = 0

        metrics = self.calculate_metrics(lm_logits, blanks_logits, mskd_label,
                                         blank_labels)
        return raw_loss / n_points, metrics[0], metrics[1]

    def _save_model(self, path, epoch, best_model: bool = False):
        """
        Serialize the model together with the training state.

        Args:
            path: Directory the checkpoint file is written to
            epoch: Zero-based epoch that just finished; stored as epoch + 1
            best_model: If True write "best_model.pth.tar", otherwise
                "checkpoint_epoch_<epoch + 1>.pth.tar"
        """
        if best_model:
            file_name = "best_model.pth.tar"
        else:
            # Format only the file name: formatting the joined path would
            # misinterpret any "{" or "}" contained in the directory name.
            file_name = "checkpoint_epoch_{0}.pth.tar".format(epoch + 1)
        model_path = os.path.join(path, file_name)

        torch.save(
            {
                "epoch": epoch + 1,
                "state_dict": self.model.state_dict(),
                "tokenizer": self.tokenizer,
                "best_mtb_bce": self._best_mtb_bce,
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
                "losses_per_epoch": self._train_loss,
                "accuracy_per_epoch": self._train_lm_acc,
                "lm_acc": self._lm_acc,
                # Key name kept as is: load_best_model reads "blanks_mse",
                # although the stored values are binary cross entropies.
                "blanks_mse": self._mtb_bce,
            },
            model_path,
        )

    def _get_logits(self, e1_e2_start, x):
        """
        Run the model on a batch and return blank and LM logits.

        Args:
            e1_e2_start: Entity marker positions, forwarded to the model
            x: Token ids; indexed as a two-dimensional tensor here

        Returns:
            Tuple (blanks_logits, lm_logits) where lm_logits only holds the
            rows belonging to positions that carry the mask token.
        """
        n_rows, n_cols = x.shape[0], x.shape[1]
        # Every non-padding position takes part in attention.
        attention_mask = (x != self.tokenizer.pad_token_id).float()
        # All tokens are assigned to segment 0.
        token_type_ids = torch.zeros((n_rows, n_cols)).long()

        if self.train_on_gpu:
            x = x.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()

        blanks_logits, lm_logits = self.model(
            x,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            e1_e2_start=e1_e2_start,
        )
        is_masked = x == self.tokenizer.mask_token_id
        return blanks_logits, lm_logits[is_masked]

    def evaluate(self) -> tuple:
        """
        Run the validation generator and return performance metrics.

        The model is put into eval mode for the run and is always switched
        back to train mode afterwards, even if a batch raises.

        Returns:
            Tuple (mean LM accuracy, mean MTB binary cross entropy). The
            cross entropy mean ignores batches whose value equals exactly 1.
            Either entry is NaN if no batch contributed to it.
        """
        lm_acc = []
        blanks_mse = []

        self.model.eval()
        try:
            with torch.no_grad():
                for data in self.data_loader.validation_generator:
                    (x, masked_label, e1_e2_start, blank_labels) = data
                    masked_label = masked_label[(
                        masked_label != self.tokenizer.pad_token_id)]
                    # Nothing to score without masked tokens.
                    if masked_label.shape[0] == 0:
                        continue
                    if self.train_on_gpu:
                        masked_label = masked_label.cuda()
                    blanks_logits, lm_logits = self._get_logits(
                        e1_e2_start, x)

                    # The loss itself is not reported for validation, so it
                    # is not computed here; only the metrics are collected.
                    eval_result = self.calculate_metrics(
                        lm_logits, blanks_logits, masked_label, blank_labels)
                    lm_acc.append(eval_result[0])
                    blanks_mse.append(eval_result[1])
        finally:
            self.model.train()

        # NOTE(review): values of exactly 1 are excluded from the mean; the
        # reason is not visible in this file - confirm before changing.
        valid_bce = [b for b in blanks_mse if b != 1]
        mean_lm_acc = np.mean(lm_acc) if lm_acc else float("nan")
        # Guard the division: an empty validation set (or only excluded
        # values) must not end in a ZeroDivisionError.
        mean_bce = (sum(valid_bce) / len(valid_bce)
                    if valid_bce else float("nan"))
        return mean_lm_acc, mean_bce

    def calculate_metrics(
        self,
        lm_logits,
        blanks_logits,
        masked_for_pred,
        blank_labels,
    ) -> tuple:
        """
        Calculates the performance metrics of the MTB model.

        The LM accuracy is the share of masked tokens whose most probable
        prediction equals the label. For the MTB part every pair of two
        positive rows is scored against target 1 and every
        (positive, negative) pair against target 0; the metric is the mean
        binary cross entropy over all these pairs.

        Args:
            lm_logits: Language model Logits per word in vocabulary
            blanks_logits: Blank logits
            masked_for_pred: List of marked tokens
            blank_labels: Blank labels

        Returns:
            Tuple (LM accuracy as float, MTB binary cross entropy as 0-d
            numpy array). The cross entropy is NaN if the batch yields no
            pair at all.
        """
        lm_logits_pred_ids = torch.softmax(lm_logits, dim=-1).max(1)[1]
        lm_accuracy = ((lm_logits_pred_ids == masked_for_pred).sum().float() /
                       len(masked_for_pred)).item()

        pos_idxs = np.where(blank_labels == 1)[0]
        neg_idxs = np.where(blank_labels == 0)[0]

        # positives: every unordered pair of positive rows
        pair_logits = [
            self._get_mtb_logits(blanks_logits[pos1, :],
                                 blanks_logits[pos2, :])
            for pos1, pos2 in combinations(pos_idxs, 2)
        ]
        pair_labels = [1.0] * len(pair_logits)

        # negatives: every positive row against every negative row
        for pos_idx in pos_idxs:
            for neg_idx in neg_idxs:
                pair_logits.append(
                    self._get_mtb_logits(blanks_logits[pos_idx, :],
                                         blanks_logits[neg_idx, :]))
        pair_labels += [0.0] * (len(pair_logits) - len(pair_labels))

        # torch.stack rejects an empty list, so a batch without any usable
        # pair is reported as "undefined" instead of crashing.
        if not pair_logits:
            return lm_accuracy, np.array(float("nan"), dtype=np.float32)

        blank_pred = torch.stack(pair_logits, dim=0)
        targets = torch.tensor(pair_labels, dtype=torch.float)
        bce = nn.BCEWithLogitsLoss(reduction="mean")(
            blank_pred.detach().cpu(), targets)

        return lm_accuracy, bce.numpy()

    @classmethod
    def _get_mtb_logits(cls, f1_vec, f2_vec):
        """
        Score two vectors by their dot product scaled with both norms.

        Args:
            f1_vec: First vector (one-dimensional tensor)
            f2_vec: Second vector (one-dimensional tensor)

        Returns:
            0-d tensor holding dot(f1_vec, f2_vec) / (|f1_vec| * |f2_vec|).
        """
        norm_product = torch.norm(f1_vec) * torch.norm(f2_vec)
        inverse_scale = 1 / norm_product
        return inverse_scale * torch.dot(f1_vec, f2_vec)
# Example #2
# 0
class Framework(object):
    """A framework wrapping the Relational Graph Extraction model. This framework allows to train, predict, evaluate,
    saving and loading the model with a single line of code.
    """
    def __init__(self, **config):
        super().__init__()

        self.config = config

        self.grad_acc = self.config[
            'grad_acc'] if 'grad_acc' in self.config else 1
        self.device = torch.device(self.config['device'])
        if isinstance(self.config['model'], str):
            self.model = MODELS[self.config['model']](**self.config)
        else:
            self.model = self.config['model']

        self.class_weights = torch.tensor(self.config['class_weights']).float(
        ) if 'class_weights' in self.config else torch.ones(
            self.config['n_rel'])
        if 'lambda' in self.config:
            self.class_weights[0] = self.config['lambda']
        self.loss_fn = nn.CrossEntropyLoss(weight=self.class_weights.to(
            self.device),
                                           reduction='mean')
        if self.config['optimizer'] == 'SGD':
            self.optimizer = torch.optim.SGD(
                self.model.get_parameters(self.config.get('l2', .01)),
                lr=self.config['lr'],
                momentum=self.config.get('momentum', 0),
                nesterov=self.config.get('nesterov', False))
        elif self.config['optimizer'] == 'Adam':
            self.optimizer = torch.optim.Adam(self.model.get_parameters(
                self.config.get('l2', .01)),
                                              lr=self.config['lr'])
        elif self.config['optimizer'] == 'AdamW':
            self.optimizer = AdamW(self.model.get_parameters(
                self.config.get('l2', .01)),
                                   lr=self.config['lr'])
        else:
            raise Exception('The optimizer must be SGD, Adam or AdamW')

    def _train_step(self, dataset, epoch, scheduler=None):
        """Train for one epoch and return the mean batch loss.

        Gradients are accumulated over `self.grad_acc` batches before the
        optimizer (and the scheduler, if given) steps. Micro-averaged
        precision/recall/F-score are shown in the progress bar and printed
        at the end, once over the labels 1..n_rel-1 and once over all
        labels.

        NOTE: gradients of trailing batches (when the number of batches is
        not a multiple of `self.grad_acc`) are not applied here.

        Args:
            dataset: Iterable of (seq, mask, ent, label) batches that
                supports len() and provides evaluate(batch_index, output)
            epoch: Epoch number, only used for display
            scheduler: Optional learning rate scheduler

        Returns:
            Mean loss over all batches; NaN if `dataset` is empty.
        """
        print("Training:")
        self.model.train()

        use_half = self.config.get('half', False)
        grad_clip = self.config.get('grad_clip', False)
        relation_labels = list(range(1, self.model.n_rel))

        total_loss = 0
        n_batches = 0
        predictions, labels = [], []
        precision = recall = fscore = 0.0
        progress = tqdm(
            enumerate(dataset),
            desc=
            f"Epoch: {epoch} - Loss: {0.0} - P/R/F: {precision}/{recall}/{fscore}",
            total=len(dataset))
        for i, batch in progress:
            # uncompress the batch
            seq, mask, ent, label = batch
            seq = seq.to(self.device)
            mask = mask.to(self.device)
            ent = ent.to(self.device)
            label = label.to(self.device)

            output = self.model(seq, mask, ent)
            loss = self.loss_fn(output, label)
            total_loss += loss.item()
            n_batches = i + 1

            if use_half:
                with amp.scale_loss(loss, self.optimizer) as scale_loss:
                    scale_loss.backward()
            else:
                loss.backward()

            if n_batches % self.grad_acc == 0:
                if grad_clip:
                    # With mixed precision the master copies are clipped.
                    clip_params = (amp.master_params(self.optimizer)
                                   if use_half else self.model.parameters())
                    nn.utils.clip_grad_norm_(clip_params, grad_clip)

                self.optimizer.step()
                self.model.zero_grad()
                if scheduler:
                    scheduler.step()

            # Evaluate results. `.cpu()` is a no-op for CPU tensors, so no
            # device check is needed before converting to numpy.
            pre, lab, _pos = dataset.evaluate(i,
                                              output.detach().cpu().numpy())
            predictions.extend(pre)
            labels.extend(lab)

            if n_batches % 10 == 0:
                precision, recall, fscore, _ = precision_recall_fscore_support(
                    np.array(labels),
                    np.array(predictions),
                    average='micro',
                    labels=relation_labels)

            progress.set_description(
                f"Epoch: {epoch} - Loss: {total_loss/(i+1):.3f} - P/R/F: {precision:.2f}/{recall:.2f}/{fscore:.2f}"
            )

        predictions, labels = np.array(predictions), np.array(labels)
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro', labels=relation_labels)
        print(
            f"Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro')
        print(
            f"[with NO-RELATION] Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )

        # The loop variable is unbound for an empty dataset, hence the
        # explicit batch counter.
        return total_loss / n_batches if n_batches else float('nan')

    def _val_step(self, dataset, epoch):
        """Evaluate the model on `dataset` without updating it.

        Args:
            dataset: Iterable of (seq, mask, ent, label) batches that
                supports len() and provides evaluate(batch_index, output)
            epoch: Epoch number, only used for display

        Returns:
            Tuple (mean loss, precision, recall, fscore). The three scores
            are micro-averaged over the labels 1..n_rel-1. The mean loss is
            NaN if `dataset` is empty.
        """
        print("Validating:")
        self.model.eval()

        relation_labels = list(range(1, self.model.n_rel))
        predictions, labels = [], []
        total_loss = 0
        n_batches = 0
        with torch.no_grad():
            progress = tqdm(enumerate(dataset),
                            desc=f"Epoch: {epoch} - Loss: {0.0}",
                            total=len(dataset))
            for i, batch in progress:
                # uncompress the batch
                seq, mask, ent, label = batch
                seq = seq.to(self.device)
                mask = mask.to(self.device)
                ent = ent.to(self.device)
                label = label.to(self.device)

                output = self.model(seq, mask, ent)
                loss = self.loss_fn(output, label)
                total_loss += loss.item()
                n_batches = i + 1

                # Evaluate results. `.cpu()` is a no-op for CPU tensors, so
                # no device check is needed before converting to numpy.
                pre, lab, _pos = dataset.evaluate(
                    i,
                    output.detach().cpu().numpy())
                predictions.extend(pre)
                labels.extend(lab)

                progress.set_description(
                    f"Epoch: {epoch} - Loss: {total_loss/(i+1):.3f}")

        predictions, labels = np.array(predictions), np.array(labels)
        precision, recall, fscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro', labels=relation_labels)
        print(
            f"Precision: {precision:.3f} - Recall: {recall:.3f} - F-Score: {fscore:.3f}"
        )
        noprecision, norecall, nofscore, _ = precision_recall_fscore_support(
            labels, predictions, average='micro')
        print(
            f"[with NO-RELATION] Precision: {noprecision:.3f} - Recall: {norecall:.3f} - F-Score: {nofscore:.3f}"
        )

        # The loop variable is unbound for an empty dataset, hence the
        # explicit batch counter.
        mean_loss = total_loss / n_batches if n_batches else float('nan')
        return mean_loss, precision, recall, fscore

    def _save_checkpoint(self, dataset, epoch, loss, val_loss):
        print(f"Saving checkpoint ({dataset.name}.pth) ...")
        PATH = os.path.join('checkpoints', f"{dataset.name}.pth")
        config_PATH = os.path.join('checkpoints',
                                   f"{dataset.name}_config.json")
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': loss,
                'val_loss': val_loss
            }, PATH)
        with open(config_PATH, 'wt') as f:
            json.dump(self.config, f)

    def _load_checkpoint(self, PATH: str, config_PATH: str):
        checkpoint = torch.load(PATH)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']

        with open(config_PATH, 'rt') as f:
            self.config = json.load(f)

        return epoch, loss

    def fit(self,
            dataset,
            validation=True,
            batch_size=1,
            patience=3,
            delta=0.):
        """ Fits the model to the given dataset.

        Trains for config['epochs'] epochs. With `validation` enabled the
        validation loss of every epoch is passed to early stopping, which
        may end training early; afterwards the checkpoint written during
        training is loaded back. Without validation no checkpoint is
        written, so the model keeps the weights of the last epoch.

        Args:
            dataset: Object providing get_train(batch_size),
                get_val(batch_size) and a `name` attribute
            validation: Whether to validate after every epoch
            batch_size: Batch size handed to get_train / get_val
            patience: Forwarded to EarlyStopping
            delta: Forwarded to EarlyStopping

        Usage:
        ```
        >>> rge = Framework(**config)
        >>> rge.fit(train_data)
        ```
        """
        self.model.to(self.device)
        train_data = dataset.get_train(batch_size)

        if self.config.get('half', False):
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level='O2',
                keep_batchnorm_fp32=True)

        if self.config.get('linear_scheduler', False):
            # One scheduler step per optimizer step, i.e. per `grad_acc`
            # batches.
            num_training_steps = int(
                len(train_data) // self.grad_acc * self.config['epochs'])
            scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config.get('warmup_steps', 0),
                num_training_steps=num_training_steps)
        else:
            scheduler = None

        early_stopping = EarlyStopping(patience, delta, self._save_checkpoint)

        for epoch in range(self.config['epochs']):
            # Drops gradients left over from an incomplete accumulation
            # window of the previous epoch.
            self.optimizer.zero_grad()
            loss = self._train_step(train_data, epoch, scheduler=scheduler)
            if validation:
                val_loss, _, _, _ = self._val_step(dataset.get_val(batch_size),
                                                   epoch)
                if early_stopping(val_loss,
                                  dataset=dataset,
                                  epoch=epoch,
                                  loss=loss):
                    break

        # Recover the best epoch. Checkpoints are only written through early
        # stopping, which needs validation; without it there is either no
        # file at all or a stale one from an earlier run that must not
        # overwrite the freshly trained weights.
        path = os.path.join("checkpoints", f"{dataset.name}.pth")
        config_path = os.path.join("checkpoints",
                                   f"{dataset.name}_config.json")
        if validation and os.path.exists(path):
            self._load_checkpoint(path, config_path)

    def predict(self, dataset, return_proba=False) -> tuple:
        """ Predicts the relations graph for the given dataset.

        Args:
            dataset: Iterable of (seq, mask, ent, label) batches that
                supports len() and provides get_instances(batch_index)
            return_proba: If False return the index of the highest scoring
                class per instance, otherwise the raw model output rows

        Returns:
            Tuple (predictions, instances) of two lists.
        """
        self.model.to(self.device)
        self.model.eval()

        predictions, instances = [], []
        with torch.no_grad():
            progress = tqdm(enumerate(dataset), total=len(dataset))
            for i, batch in progress:
                # uncompress the batch; the label is not needed to predict,
                # so it is not transferred to the device
                seq, mask, ent, _label = batch
                seq = seq.to(self.device)
                mask = mask.to(self.device)
                ent = ent.to(self.device)

                scores = self.model(seq, mask, ent).detach().cpu().numpy()
                if return_proba:
                    pred = scores.tolist()
                else:
                    pred = np.argmax(scores, axis=1).tolist()

                predictions.extend(pred)
                instances.extend(dataset.get_instances(i))

        return predictions, instances

    def evaluate(self, dataset: Dataset, batch_size=1) -> torch.Tensor:
        """ Evaluates the model given for the given dataset.
        """
        loss, precision, recall, fscore = self._val_step(
            dataset.get_val(batch_size), 0)
        return loss, precision, recall, fscore

    def save_model(self, path: str):
        """ Saves the model to a file.

        Usage:
        ``` 
        >>> rge = Framework(**config)
        >>> rge.fit(train_data)

        >>> rge.save_model("path/to/file")
        ```

        TODO
        """
        self.model.save_pretrained(path)
        with open(f"{path}/fine_tunning.config.json", 'wt') as f:
            json.dump(self.config, f, indent=4)

    @classmethod
    def load_model(cls,
                   path: str,
                   config_path: str = None,
                   from_checkpoint=False):
        """ Loads the model from a file.

        Args:
            path: str Path to the file that stores the model.

        Returns:
            Framework instance with the loaded model.

        Usage:
        ```
        >>> rge = Framework.load_model("path/to/model")
        ```

        TODO
        """
        if not from_checkpoint:
            config_path = path + '/fine_tunning.config.json'
            with open(config_path) as f:
                config = json.load(f)
            config['pretrained_model'] = path
            rge = cls(**config)

        else:
            if config_path is None:
                raise Exception(
                    'Loading the model from a checkpoint requires config_path argument.'
                )
            with open(config_path) as f:
                config = json.load(f)
            rge = cls(**config)
            rge._load_checkpoint(path, config_path)

        return rge
def model_train_validate_test(train_df, dev_df, test_df, target_dir, 
         max_seq_len=64,
         num_labels=2,
         epochs=10,
         batch_size=32,
         lr=2e-05,
         patience=1,
         max_grad_norm=10.0,
         if_save_model=True,
         checkpoint=None):
    """Train a ``DistilBertModel`` classifier with early stopping.

    After every epoch the model is validated on ``dev_df``. Whenever the
    validation accuracy is at least the best seen so far, the model is
    (optionally) checkpointed to ``<target_dir>/best.pth.tar``, evaluated on
    ``test_df``, and the test predictions are written to
    ``<target_dir>/test_prediction.csv``.

    Relies on the module-level helpers ``train`` (returns
    ``(time, loss, accuracy)``) and ``validate`` (returns
    ``(time, loss, accuracy, predictions)``).

    Args:
        train_df, dev_df, test_df: data handed to ``DataPrecessForSentence``.
        target_dir: output directory; created if missing.
        max_seq_len: sequence length passed to ``DataPrecessForSentence``.
        num_labels: number of classes passed to ``DistilBertModel``.
        epochs: last epoch number to run (epochs are 1-based).
        batch_size: batch size of all three data loaders.
        lr: learning rate for AdamW.
        patience: number of consecutive epochs with validation accuracy below
            the best score before training stops.
        max_grad_norm: forwarded to ``train``.
        if_save_model: whether to write ``best.pth.tar`` on improvement.
        checkpoint: optional path of a checkpoint written by this function,
            from which training is resumed.

    Returns:
        None.
    """
    bertmodel = DistilBertModel(requires_grad = True, num_labels = num_labels)
    tokenizer = bertmodel.tokenizer

    print(20 * "=", " Preparing for training ", 20 * "=")
    # Directory where the model is saved; create it if it does not exist.
    os.makedirs(target_dir, exist_ok=True)
    # -------------------- Data loading ------------------- #
    print("\t* Loading training data...")
    train_data = DataPrecessForSentence(tokenizer, train_df, max_seq_len)
    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading validation data...")
    dev_data = DataPrecessForSentence(tokenizer, dev_df, max_seq_len)
    dev_loader = DataLoader(dev_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading test data...")
    test_data = DataPrecessForSentence(tokenizer, test_df, max_seq_len)
    # The test loader is not shuffled, so predictions keep the order of test_df.
    test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)
    # -------------------- Model definition ------------------- #
    print("\t* Building model...")
    # Fall back to CPU instead of failing outright when CUDA is unavailable.
    # NOTE(review): `train`/`validate` are defined elsewhere and may still
    # assume a CUDA device - confirm before relying on the CPU path.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = bertmodel.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')
    # -------------------- Preparation for training  ------------------- #
    # Parameters to optimize: weight decay is applied to everything except
    # biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
            {
                    'params':[p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                    'weight_decay':0.01
            },
            {
                    'params':[p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                    'weight_decay':0.0
            }
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

    # When the evaluation metric stops improving, lower the learning rate
    # (mode="max" because the monitored value is validation accuracy).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.85, patience=2, verbose=True)

    best_score = 0.0
    start_epoch = 1
    # Data for loss curves plot
    epochs_count = []
    train_losses = []
    train_accuracies = []
    valid_losses = []
    valid_accuracies = []
    # Continuing training from a checkpoint if one was given as argument
    if checkpoint:
        state = torch.load(checkpoint, map_location=device)
        start_epoch = state["epoch"] + 1
        best_score = state["best_score"]
        print("\t* Training will continue on existing model from epoch {}...".format(start_epoch))
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        epochs_count = state["epochs_count"]
        train_losses = state["train_losses"]
        # Restore into the lists that are appended to and saved below;
        # otherwise the accuracy history of earlier epochs is dropped.
        train_accuracies = state["train_accuracy"]
        valid_losses = state["valid_losses"]
        valid_accuracies = state["valid_accuracy"]
    # Compute loss and accuracy before starting (or resuming) training.
    _, valid_loss, valid_accuracy, _, = validate(model, dev_loader)
    print("\n* Validation loss before training: {:.4f}, accuracy: {:.4f}%".format(valid_loss, (valid_accuracy*100)))
    # -------------------- Training epochs ------------------- #
    print("\n", 20 * "=", "Training roberta model on device: {}".format(device), 20 * "=")
    patience_counter = 0
    for epoch in range(start_epoch, epochs + 1):
        epochs_count.append(epoch)

        print("* Training epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy = train(model, train_loader, optimizer, epoch, max_grad_norm)
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)

        print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%".format(epoch_time, epoch_loss, (epoch_accuracy*100)))

        print("* Validation for epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy, _, = validate(model, dev_loader)
        valid_losses.append(epoch_loss)
        valid_accuracies.append(epoch_accuracy)
        print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n"
              .format(epoch_time, epoch_loss, (epoch_accuracy*100)))

        # Update the optimizer's learning rate with the scheduler.
        scheduler.step(epoch_accuracy)
        # Early stopping on validation accuracy.
        if epoch_accuracy < best_score:
            patience_counter += 1
        else:
            best_score = epoch_accuracy
            patience_counter = 0

            if if_save_model:
                torch.save({"epoch": epoch,
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "best_score": best_score,  # best accuracy on the validation set
                        "epochs_count": epochs_count,
                        "train_losses": train_losses,
                        "train_accuracy": train_accuracies,
                        "valid_losses": valid_losses,
                        "valid_accuracy": valid_accuracies
                        },
                        os.path.join(target_dir, "best.pth.tar"))
                print("save model succesfully!\n")

            print("* Test for epoch {}:".format(epoch))
            _, _, test_accuracy, predictions = validate(model, test_loader)
            # Scale to percent like the training/validation prints above.
            print("Test accuracy: {:.4f}%\n".format(test_accuracy*100))
            test_prediction = pd.DataFrame({'prediction':predictions})
            test_prediction.to_csv(os.path.join(target_dir,"test_prediction.csv"), index=False)

        if patience_counter >= patience:
            print("-> Early stopping: patience limit reached, stopping...")
            break
# Example #4
# 0
class Trainer(object):
    """Training / validation / testing driver configured by a protocol dict.

    The protocol (``proto``) is a nested dict with at least the sections
    ``"model"`` (must contain ``"name"``) and ``"data"``; see ``__init__`` for
    the keys that are read. The class depends on several module-level names
    defined elsewhere in the file (``fetch_nn``, ``fetch_scheduler``,
    ``fetch_loss``, ``ChineseTextSet``, ``Logger``, ``AverageMeter``,
    ``Timer``, ``write_predictions``, ``arg``, ``PAD``, ``CLS`` and the metric
    functions ``f1``, ``precision``, ``recall``, ``accuracy``).
    """

    def __init__(self, proto, stage="train"):
        """Read the protocol and set up everything needed for ``stage``.

        Args:
            proto: protocol dict. Keys read here: ``model`` (``name`` plus
                the model constructor kwargs), ``data`` (``train_path``,
                ``val_path``, ``pad``, ``train_batch_size``,
                ``val_batch_size``, ``skip_first``, ``delimiter``),
                ``optimizer``, ``schedulers``, ``loss``, ``device``; for
                training also ``num_steps``, ``log_steps``, ``val_steps``;
                for testing also ``id`` and ``ckpt_fold``.
            stage: ``"train"`` or ``"test"``; any other value only performs
                the common configuration.

        Raises:
            ValueError: if the data paths or batch sizes required by the
                selected stage are missing.
        """
        # model config
        model_cfg = proto["model"]
        model_name = model_cfg["name"]
        self.model_name = model_name
        # Constructor kwargs are the model section without "name". A copy is
        # built instead of popping so the caller's proto (which is dumped to
        # protocol.yml below) keeps the model name.
        model_kwargs = {key: value for key, value in model_cfg.items() if key != "name"}

        # dataset config
        data_cfg = proto["data"]
        train_data_path = data_cfg.get("train_path", None)
        val_data_path = data_cfg.get("val_path", None)
        pad = data_cfg.get("pad", 32)
        train_bs = data_cfg.get("train_batch_size", None)
        val_bs = data_cfg.get("val_batch_size", None)
        self.val_bs = val_bs
        self.skip_first = data_cfg.get("skip_first", False)
        self.delimiter = data_cfg.get("delimiter", "\t")

        # assorted config
        optim_cfg = proto.get("optimizer", {"lr": 0.00003})
        sched_cfg = proto.get("schedulers", None)
        loss = proto.get("loss", "CE")
        self.device = proto.get("device", None)

        if torch.cuda.is_available() and self.device is not None:
            # %s so that both integer indices and string specs can be printed.
            print("Using device: %s." % self.device)
            self.device = torch.device(self.device)
            self.gpu = True
        else:
            print("Using cpu device.")
            self.device = torch.device("cpu")
            self.gpu = False

        if stage == "train":
            if train_data_path is None or val_data_path is None:
                raise ValueError("Please specify both train and val data path.")
            if train_bs is None or val_bs is None:
                raise ValueError("Please specify both train and val batch size.")
            # loading model
            self.model = fetch_nn(model_name)(**model_kwargs)
            # Only move to GPU when one was selected above (mirrors the test
            # stage); calling .cuda() with the CPU fallback device would fail.
            if self.gpu:
                self.model = self.model.cuda(self.device)

            # loading dataset and converting into dataloader
            self.train_data = ChineseTextSet(
                path=train_data_path, tokenizer=self.model.tokenizer, pad=pad,
                delimiter=self.delimiter, skip_first=self.skip_first)
            self.train_loader = DataLoader(
                self.train_data, train_bs, shuffle=True, num_workers=4)
            self.val_data = ChineseTextSet(
                path=val_data_path, tokenizer=self.model.tokenizer, pad=pad,
                delimiter=self.delimiter, skip_first=self.skip_first)
            self.val_loader = DataLoader(
                self.val_data, val_bs, shuffle=True, num_workers=4)

            # Each run gets its own timestamped record directory.
            time_format = "%Y-%m-%d...%H.%M.%S"
            run_id = time.strftime(time_format, time.localtime(time.time()))
            self.record_path = os.path.join(arg.record, model_name, run_id)

            os.makedirs(self.record_path)
            # From here on, stdout goes through the project Logger.
            sys.stdout = Logger(os.path.join(self.record_path, 'records.txt'))
            print("Writing proto file to file directory: %s." % self.record_path)
            with open(os.path.join(self.record_path, 'protocol.yml'), 'w') as proto_file:
                yaml.dump(proto, proto_file)

            print("*" * 25, " PROTO BEGINS ", "*" * 25)
            pprint(proto)
            print("*" * 25, " PROTO ENDS ", "*" * 25)

            self.optimizer = AdamW(self.model.parameters(), **optim_cfg)
            self.scheduler = fetch_scheduler(self.optimizer, sched_cfg)

            self.loss = fetch_loss(loss)

            self.best_f1 = 0.0
            self.best_step = 1
            self.start_step = 1

            self.num_steps = proto["num_steps"]
            self.num_epoch = math.ceil(self.num_steps / len(self.train_loader))

            # the number of steps to write down a log
            self.log_steps = proto["log_steps"]
            # the number of steps to validate on val dataset once
            self.val_steps = proto["val_steps"]

            self.f1_meter = AverageMeter()
            self.p_meter = AverageMeter()
            self.r_meter = AverageMeter()
            self.acc_meter = AverageMeter()
            self.loss_meter = AverageMeter()

        if stage == "test":
            if val_data_path is None:
                raise ValueError("Please specify the val data path.")
            if val_bs is None:
                raise ValueError("Please specify the val batch size.")
            run_id = proto["id"]
            ckpt_fold = proto.get("ckpt_fold", "runs")
            self.record_path = os.path.join(ckpt_fold, model_name, run_id)
            sys.stdout = Logger(os.path.join(self.record_path, 'tests.txt'))

            config, state_dict, fc_dict = self._load_ckpt(best=True, train=False)
            weights = {"config": config, "state_dict": state_dict}
            # loading trained model using config and state_dict
            self.model = fetch_nn(model_name)(weights=weights)
            # loading the weights for the final fc layer
            self.model.load_state_dict(fc_dict, strict=False)
            # loading model to gpu device if specified
            if self.gpu:
                self.model = self.model.cuda(self.device)

            print("Testing directory: %s." % self.record_path)
            print("*" * 25, " PROTO BEGINS ", "*" * 25)
            pprint(proto)
            print("*" * 25, " PROTO ENDS ", "*" * 25)

            self.val_path = val_data_path
            self.test_data = ChineseTextSet(
                path=val_data_path, tokenizer=self.model.tokenizer, pad=pad,
                delimiter=self.delimiter, skip_first=self.skip_first)
            self.test_loader = DataLoader(
                self.test_data, val_bs, shuffle=True, num_workers=4)

    def _save_ckpt(self, step, best=False, f=None, p=None, r=None):
        """Save a checkpoint into ``self.record_path``.

        Args:
            step: global step stored in the checkpoint.
            best: write ``best_model.bin`` if true, else ``latest_model.bin``.
            f, p, r: f1 / precision / recall values stored alongside.
        """
        save_dir = os.path.join(self.record_path, "best_model.bin" if best else "latest_model.bin")
        torch.save({
            "step": step,
            "f1": f,
            "precision": p,
            "recall": r,
            "best_step": self.best_step,
            "best_f1": self.best_f1,
            "model": self.model.state_dict(),
            "config": self.model.config,
            "optimizer": self.optimizer.state_dict(),
            "schedulers": self.scheduler.state_dict(),
        }, save_dir)

    def _load_ckpt(self, best=False, train=False):
        """Load a checkpoint written by ``_save_ckpt``.

        Restores ``start_step``, ``best_step`` and ``best_f1`` on the
        instance, and the optimizer / scheduler state when ``train`` is true.

        Args:
            best: read ``best_model.bin`` if true, else ``latest_model.bin``.
            train: also restore optimizer and scheduler state.

        Returns:
            ``(config, model_state_dict, fc_dict)`` where ``fc_dict`` holds
            only the ``fc.weight`` and ``fc.bias`` entries of the model state.
        """
        load_dir = os.path.join(self.record_path, "best_model.bin" if best else "latest_model.bin")
        load_dict = torch.load(load_dir, map_location=self.device)
        self.start_step = load_dict["step"]
        self.best_step = load_dict["best_step"]
        self.best_f1 = load_dict["best_f1"]
        if train:
            self.optimizer.load_state_dict(load_dict["optimizer"])
            self.scheduler.load_state_dict(load_dict["schedulers"])
        print("Loading checkpoint from %s, best step: %d, best f1: %.4f."
              % (load_dir, self.best_step, self.best_f1))
        # Per-step metrics are only reported for the "latest" checkpoint;
        # validate() saves the "best" one without f/p/r values.
        if not best:
            print("Checkpoint step %s, f1: %.4f, precision: %.4f, recall: %.4f."
                  % (self.start_step, load_dict["f1"],
                     load_dict["precision"], load_dict["recall"]))
        fc_dict = {
            "fc.weight": load_dict["model"]["fc.weight"],
            "fc.bias": load_dict["model"]["fc.bias"]
        }
        return load_dict["config"], load_dict["model"], fc_dict

    def to_cuda(self, *args):
        """Return a list with every argument moved via ``.cuda(self.device)``."""
        return [obj.cuda(self.device) for obj in args]

    @staticmethod
    def fixed_randomness():
        """Seed ``random`` and torch (CPU and CUDA) with 0 and make cuDNN deterministic."""
        random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        torch.cuda.manual_seed_all(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    @staticmethod
    def update_metrics(gt, pre, f1_m, p_m, r_m, acc_m):
        """Compute batch metrics and push them into the given meters.

        Args:
            gt: ground-truth labels of the batch.
            pre: predicted labels of the batch.
            f1_m, p_m, r_m, acc_m: meters updated (via ``update``) with the
                micro-averaged f1 / precision / recall and the accuracy.
        """
        f1_value = f1(gt, pre, average="micro")
        f1_m.update(f1_value)
        p_value = precision(gt, pre, average="micro", zero_division=0)
        p_m.update(p_value)
        r_value = recall(gt, pre, average="micro")
        r_m.update(r_value)
        acc_value = accuracy(gt, pre)
        acc_m.update(acc_value)

    def train(self):
        """Run the training loop until ``num_epoch`` epochs or ``num_steps`` steps.

        Logs to a ``SummaryWriter`` in ``self.record_path`` every
        ``log_steps`` global steps and calls ``validate`` every ``val_steps``
        global steps.
        """
        timer = Timer()
        writer = SummaryWriter(self.record_path)
        print("*" * 25, " TRAINING BEGINS ", "*" * 25)
        start_epoch = self.start_step // len(self.train_loader) + 1
        for epoch_idx in range(start_epoch, self.num_epoch + 1):
            self.f1_meter.reset()
            self.p_meter.reset()
            self.r_meter.reset()
            self.acc_meter.reset()
            self.loss_meter.reset()
            # NOTE(review): optimizer.step() here runs before any backward
            # pass of this epoch, so from the second epoch on it re-applies
            # the gradients left over from the previous epoch's last batch.
            # Presumably added to avoid PyTorch's scheduler-before-optimizer
            # warning; kept as-is because the scheduler semantics are not
            # visible here - confirm and consider stepping at epoch end.
            self.optimizer.step()
            self.scheduler.step()
            train_generator = tqdm(enumerate(self.train_loader, 1), position=0, leave=True)

            for batch_idx, data in train_generator:
                global_step = (epoch_idx - 1) * len(self.train_loader) + batch_idx
                self.model.train()
                input_ids, label, _, mask = data[:4]
                if self.gpu:
                    input_ids, mask, label = self.to_cuda(input_ids, mask, label)
                pre = self.model((input_ids, mask))
                loss = self.loss(pre, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                lbl = label.cpu().numpy()
                yp = pre.argmax(1).cpu().numpy()
                self.update_metrics(
                    lbl, yp, self.f1_meter, self.p_meter,
                    self.r_meter, self.acc_meter
                )
                self.loss_meter.update(loss.item())

                if global_step % self.log_steps == 0 and writer is not None:
                    writer.add_scalar("train/f1", self.f1_meter.avg, global_step)
                    writer.add_scalar("train/loss", self.loss_meter.avg, global_step)
                    writer.add_scalar("train/lr", self.scheduler.get_lr()[0], global_step)

                train_generator.set_description(
                    "Train Epoch %d (%d/%d), "
                    "Global Step %d, Loss %.4f, f1 %.4f, p %.4f, r %.4f, acc %.4f, LR %.6f" % (
                        epoch_idx, batch_idx, len(self.train_loader), global_step,
                        self.loss_meter.avg, self.f1_meter.avg,
                        self.p_meter.avg, self.r_meter.avg,
                        self.acc_meter.avg,
                        self.scheduler.get_lr()[0]
                    )
                )

                # validating process
                if global_step % self.val_steps == 0:
                    print()
                    self.validate(epoch_idx, global_step, timer, writer)

                # when num_steps has been set and the training process will
                # be stopped earlier than the specified num_epochs, then stop.
                if self.num_steps is not None and global_step == self.num_steps:
                    if writer is not None:
                        writer.close()
                    print()
                    print("*" * 25, " TRAINING ENDS ", "*" * 25)
                    return

            train_generator.close()
            print()
        writer.close()
        print("*" * 25, " TRAINING ENDS ", "*" * 25)

    def validate(self, epoch, step, timer, writer):
        """Evaluate on the validation loader and save checkpoints.

        Saves ``best_model.bin`` when the averaged f1 exceeds
        ``self.best_f1`` and always saves ``latest_model.bin``.

        Args:
            epoch: current epoch index (used for display only).
            step: current global step.
            timer: object whose ``measure`` method yields the running /
                estimated time strings for the summary line.
            writer: summary writer for the ``val/*`` scalars, or ``None``.
        """
        with torch.no_grad():
            f1_meter = AverageMeter()
            p_meter = AverageMeter()
            r_meter = AverageMeter()
            acc_meter = AverageMeter()
            loss_meter = AverageMeter()
            val_generator = tqdm(enumerate(self.val_loader, 1), position=0, leave=True)
            for val_idx, data in val_generator:
                self.model.eval()
                input_ids, label, _, mask = data[:4]
                if self.gpu:
                    input_ids, mask, label = self.to_cuda(input_ids, mask, label)
                pre = self.model((input_ids, mask))
                loss = self.loss(pre, label)

                lbl = label.cpu().numpy()
                yp = pre.argmax(1).cpu().numpy()
                self.update_metrics(lbl, yp, f1_meter, p_meter, r_meter, acc_meter)
                loss_meter.update(loss.item())

                val_generator.set_description(
                    "Eval Epoch %d (%d/%d), Global Step %d, Loss %.4f, "
                    "f1 %.4f, p %.4f, r %.4f, acc %.4f" % (
                        epoch, val_idx, len(self.val_loader), step,
                        loss_meter.avg, f1_meter.avg,
                        p_meter.avg, r_meter.avg, acc_meter.avg
                    )
                )

            print("Eval Epoch %d, f1 %.4f" % (epoch, f1_meter.avg))
            if writer is not None:
                writer.add_scalar("val/loss", loss_meter.avg, step)
                writer.add_scalar("val/f1", f1_meter.avg, step)
                writer.add_scalar("val/precision", p_meter.avg, step)
                writer.add_scalar("val/recall", r_meter.avg, step)
                writer.add_scalar("val/acc", acc_meter.avg, step)
            if f1_meter.avg > self.best_f1:
                self.best_f1 = f1_meter.avg
                self.best_step = step
                self._save_ckpt(step, best=True)
            print("Best Step %d, Best f1 %.4f, Running Time: %s, Estimated Time: %s" % (
                self.best_step, self.best_f1, timer.measure(), timer.measure(step / self.num_steps)
            ))
            self._save_ckpt(step, best=False, f=f1_meter.avg, p=p_meter.avg, r=r_meter.avg)

    def test(self):
        """Evaluate on the test loader and write results to ``self.record_path``.

        Prints one sample prediction per batch, writes all predictions via
        ``write_predictions`` to ``results.txt``, pickles the output of
        ``precision_recall_curve`` to ``pr.values`` and, if ``arg.image`` is
        set, saves and shows a P-R plot.
        """
        # Position (within each batch) of the sample that is displayed.
        # Drawn before fixed_randomness(), so it differs between runs.
        t_idx = random.randint(0, 5)
        with torch.no_grad():
            self.fixed_randomness()  # for reproduction

            # for writing the total predictions to disk
            data_idxs = list()
            all_preds = list()

            # for ploting P-R Curve
            predicts = list()
            truths = list()

            # for showing predicted samples
            show_ctxs = list()
            pred_lbls = list()
            targets = list()

            f1_meter = AverageMeter()
            p_meter = AverageMeter()
            r_meter = AverageMeter()
            accuracy_meter = AverageMeter()
            test_generator = tqdm(enumerate(self.test_loader, 1))
            for idx, data in test_generator:
                self.model.eval()
                input_ids, label, _, mask, data_idx = data
                if self.gpu:
                    input_ids, mask, label = self.to_cuda(input_ids, mask, label)
                pre = self.model((input_ids, mask))

                lbl = label.cpu().numpy()
                yp = pre.argmax(1).cpu().numpy()
                self.update_metrics(lbl, yp, f1_meter, p_meter, r_meter, accuracy_meter)

                test_generator.set_description(
                    "Test %d/%d, f1 %.4f, p %.4f, r %.4f, acc %.4f"
                    % (idx, len(self.test_loader), f1_meter.avg,
                       p_meter.avg, r_meter.avg, accuracy_meter.avg)
                )

                data_idxs.append(data_idx.numpy())
                all_preds.append(yp)

                # Column 1 of the model output is used as the score for the
                # P-R curve.
                predicts.append(torch.select(pre, dim=1, index=1).cpu().numpy())
                truths.append(lbl)

                # show some of the sample; clamp the index so that batches
                # with fewer than t_idx + 1 items (small batch size or a
                # short final batch) do not raise an IndexError.
                sample_idx = min(t_idx, input_ids.size(0) - 1)
                ctx = torch.select(input_ids, dim=0, index=sample_idx).detach()
                ctx = self.model.tokenizer.convert_ids_to_tokens(ctx)
                ctx = "".join(tok for tok in ctx if tok not in (PAD, CLS))

                show_ctxs.append(ctx)
                pred_lbls.append(yp[sample_idx])
                targets.append(lbl[sample_idx])

            print("*" * 25, " SAMPLE BEGINS ", "*" * 25)
            for c, t, l in zip(show_ctxs, targets, pred_lbls):
                print("ctx: ", c, " gt: ", t, " est: ", l)
            print("*" * 25, " SAMPLE ENDS ", "*" * 25)
            print("Test, FINAL f1 %.4f, "
                  "p %.4f, r %.4f, acc %.4f\n" %
                  (f1_meter.avg, p_meter.avg, r_meter.avg, accuracy_meter.avg))

            # output the final results to disk
            data_idxs = np.concatenate(data_idxs, axis=0)
            all_preds = np.concatenate(all_preds, axis=0)
            write_predictions(
                self.val_path, os.path.join(self.record_path, "results.txt"),
                data_idxs, all_preds, delimiter=self.delimiter, skip_first=self.skip_first
            )

            # output the p-r values for future plotting P-R Curve
            predicts = np.concatenate(predicts, axis=0)
            truths = np.concatenate(truths, axis=0)
            values = precision_recall_curve(truths, predicts)
            with open(os.path.join(self.record_path, "pr.values"), "wb") as f:
                pickle.dump(values, f)
            p_value, r_value, _ = values

            # plot P-R Curve if specified
            if arg.image:
                plt.figure()
                plt.plot(
                    p_value, r_value,
                    label="%s (ACC: %.2f, F1: %.2f)"
                          % (self.model_name, accuracy_meter.avg, f1_meter.avg)
                )
                plt.legend(loc="best")
                plt.title("2-Classes P-R curve")
                plt.xlabel("precision")
                plt.ylabel("recall")
                plt.savefig(os.path.join(self.record_path, "P-R.png"))
                plt.show()
# Example #5
# 0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default='bert-base-uncased',
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        '--task',
        type=str,
        default=None,
        required=True,
        help="Task code in {hotpot_open, hotpot_distractor, squad, nq}")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=378,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=5,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam. (def: 5e-5)")
    parser.add_argument("--num_train_epochs",
                        default=5.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument('--local_rank', default=-1, type=int)

    # RNN graph retriever-specific parameters
    parser.add_argument("--example_limit", default=None, type=int)

    parser.add_argument("--max_para_num", default=10, type=int)
    parser.add_argument(
        "--neg_chunk",
        default=8,
        type=int,
        help="The chunk size of negative examples during training (to "
        "reduce GPU memory consumption with negative sampling)")
    parser.add_argument(
        "--eval_chunk",
        default=100000,
        type=int,
        help=
        "The chunk size of evaluation examples (to reduce RAM consumption during evaluation)"
    )
    parser.add_argument(
        "--split_chunk",
        default=300,
        type=int,
        help=
        "The chunk size of BERT encoding during inference (to reduce GPU memory consumption)"
    )

    parser.add_argument('--train_file_path',
                        type=str,
                        default=None,
                        help="File path to the training data")
    parser.add_argument('--dev_file_path',
                        type=str,
                        default=None,
                        help="File path to the eval data")

    parser.add_argument('--beam', type=int, default=1, help="Beam size")
    parser.add_argument('--min_select_num',
                        type=int,
                        default=1,
                        help="Minimum number of selected paragraphs")
    parser.add_argument('--max_select_num',
                        type=int,
                        default=3,
                        help="Maximum number of selected paragraphs")
    parser.add_argument(
        "--use_redundant",
        action='store_true',
        help="Whether to use simulated seqs (only for training)")
    parser.add_argument(
        "--use_multiple_redundant",
        action='store_true',
        help="Whether to use multiple simulated seqs (only for training)")
    parser.add_argument(
        '--max_redundant_num',
        type=int,
        default=100000,
        help=
        "Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)"
    )
    parser.add_argument(
        "--no_links",
        action='store_true',
        help=
        "Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)"
    )
    parser.add_argument("--pruning_by_links",
                        action='store_true',
                        help="Whether to do pruning by links (and top 1)")
    parser.add_argument(
        "--expand_links",
        action='store_true',
        help=
        "Whether to expand links with paragraphs in the same article (for NQ)")
    parser.add_argument(
        '--tfidf_limit',
        type=int,
        default=None,
        help=
        "Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)"
    )

    parser.add_argument("--pred_file",
                        default=None,
                        type=str,
                        help="File name to write paragraph selection results")
    parser.add_argument("--tagme",
                        action='store_true',
                        help="Whether to use tagme at inference")
    parser.add_argument(
        '--topk',
        type=int,
        default=2,
        help="Whether to use how many paragraphs from the previous steps")

    parser.add_argument(
        "--model_suffix",
        default=None,
        type=str,
        help="Suffix to load a model file ('pytorch_model_' + suffix +'.bin')")

    parser.add_argument("--db_save_path",
                        default=None,
                        type=str,
                        help="File path to DB")
    parser.add_argument("--fp16", default=False, action='store_true')
    parser.add_argument("--fp16_opt_level", default="O1", type=str)
    parser.add_argument("--do_label",
                        default=False,
                        action='store_true',
                        help="For pre-processing features only.")

    parser.add_argument("--oss_cache_dir", default=None, type=str)
    parser.add_argument("--cache_dir", default=None, type=str)
    parser.add_argument("--dist",
                        default=False,
                        action='store_true',
                        help='use distributed training.')
    parser.add_argument("--save_steps", default=5000, type=int)
    parser.add_argument("--resume", default=None, type=int)
    parser.add_argument("--oss_pretrain", default=None, type=str)
    parser.add_argument("--model_version", default='v1', type=str)
    parser.add_argument("--disable_rnn_layer_norm",
                        default=False,
                        action='store_true')

    args = parser.parse_args()

    if args.dist:
        dist.init_process_group(backend='nccl')
        print(f"local rank: {args.local_rank}")
        print(f"global rank: {dist.get_rank()}")
        print(f"world size: {dist.get_world_size()}")

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        dist.init_process_group(backend='nccl')

    if args.dist:
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()
        if world_size > 1:
            args.local_rank = global_rank

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.train_file_path is not None:
        do_train = True

        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.output_dir))
        if args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)

    elif args.dev_file_path is not None:
        do_train = False

    else:
        raise ValueError(
            'One of train_file_path: {} or dev_file_path: {} must be non-None'.
            format(args.train_file_path, args.dev_file_path))

    processor = DataProcessor()

    # Configurations of the graph retriever
    graph_retriever_config = GraphRetrieverConfig(
        example_limit=args.example_limit,
        task=args.task,
        max_seq_length=args.max_seq_length,
        max_select_num=args.max_select_num,
        max_para_num=args.max_para_num,
        tfidf_limit=args.tfidf_limit,
        train_file_path=args.train_file_path,
        use_redundant=args.use_redundant,
        use_multiple_redundant=args.use_multiple_redundant,
        max_redundant_num=args.max_redundant_num,
        dev_file_path=args.dev_file_path,
        beam=args.beam,
        min_select_num=args.min_select_num,
        no_links=args.no_links,
        pruning_by_links=args.pruning_by_links,
        expand_links=args.expand_links,
        eval_chunk=args.eval_chunk,
        tagme=args.tagme,
        topk=args.topk,
        db_save_path=args.db_save_path,
        disable_rnn_layer_norm=args.disable_rnn_layer_norm)

    logger.info(graph_retriever_config)
    logger.info(args)

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)

    if args.model_version == 'roberta':
        from modeling_graph_retriever_roberta import RobertaForGraphRetriever
    elif args.model_version == 'v3':
        from modeling_graph_retriever_roberta import RobertaForGraphRetrieverIterV3 as RobertaForGraphRetriever
    else:
        raise RuntimeError()

    ##############################
    # Training                   #
    ##############################
    if do_train:
        _model_state_dict = None
        if args.oss_pretrain is not None:
            _model_state_dict = torch.load(load_pretrain_from_oss(
                args.oss_pretrain),
                                           map_location='cpu')
            logger.info(f"Loaded pretrained model from {args.oss_pretrain}")

        if args.resume is not None:
            _model_state_dict = torch.load(load_buffer_from_oss(
                os.path.join(args.oss_cache_dir,
                             f"pytorch_model_{args.resume}.bin")),
                                           map_location='cpu')

        model = RobertaForGraphRetriever.from_pretrained(
            args.bert_model,
            graph_retriever_config=graph_retriever_config,
            state_dict=_model_state_dict)

        model.to(device)

        global_step = 0

        POSITIVE = 1.0
        NEGATIVE = 0.0

        _cache_file_name = f"cache_roberta_train_{args.max_seq_length}_{args.max_para_num}"
        _examples_cache_file_name = f"examples_{_cache_file_name}"
        _features_cache_file_name = f"features_{_cache_file_name}"

        # Load training examples
        logger.info(f"Loading training examples and features.")
        try:
            if args.cache_dir is not None and os.path.exists(
                    os.path.join(args.cache_dir, _features_cache_file_name)):
                logger.info(
                    f"Loading pre-processed features from {os.path.join(args.cache_dir, _features_cache_file_name)}"
                )
                train_features = torch.load(
                    os.path.join(args.cache_dir, _features_cache_file_name))
            else:
                # train_examples = torch.load(load_buffer_from_oss(os.path.join(oss_features_cache_dir,
                #                                                               _examples_cache_file_name)))
                train_features = torch.load(
                    load_buffer_from_oss(
                        os.path.join(oss_features_cache_dir,
                                     _features_cache_file_name)))
                logger.info(
                    f"Pre-processed features are loaded from oss: "
                    f"{os.path.join(oss_features_cache_dir, _features_cache_file_name)}"
                )
        except:
            train_examples = processor.get_train_examples(
                graph_retriever_config)
            train_features = convert_examples_to_features(
                train_examples,
                args.max_seq_length,
                args.max_para_num,
                graph_retriever_config,
                tokenizer,
                train=True)
            logger.info(
                f"Saving pre-processed features into oss: {oss_features_cache_dir}"
            )
            torch_save_to_oss(
                train_examples,
                os.path.join(oss_features_cache_dir,
                             _examples_cache_file_name))
            torch_save_to_oss(
                train_features,
                os.path.join(oss_features_cache_dir,
                             _features_cache_file_name))

        if args.do_label:
            logger.info("Finished.")
            return

        # len(train_examples) and len(train_features) can be different, depending on the redundant setting
        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        t_total = num_train_steps
        if args.local_rank != -1:
            t_total = t_total // dist.get_world_size()

        optimizer = AdamW(optimizer_grouped_parameters,
                          betas=(0.9, 0.98),
                          lr=args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(t_total * args.warmup_proportion), t_total)

        logger.info(optimizer)
        if args.fp16:
            from apex import amp
            amp.register_half_function(torch, "einsum")

            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.fp16_opt_level)

        if args.local_rank != -1:
            if args.fp16_opt_level == 'O2':
                try:
                    import apex
                    model = apex.parallel.DistributedDataParallel(
                        model, delay_allreduce=True)
                except ImportError:
                    model = torch.nn.parallel.DistributedDataParallel(
                        model, find_unused_parameters=True)
            else:
                model = torch.nn.parallel.DistributedDataParallel(
                    model, find_unused_parameters=True)

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if args.resume is not None:
            _amp_state_dict = os.path.join(args.oss_cache_dir,
                                           f"amp_{args.resume}.bin")
            _optimizer_state_dict = os.path.join(
                args.oss_cache_dir, f"optimizer_{args.resume}.pt")
            _scheduler_state_dict = os.path.join(
                args.oss_cache_dir, f"scheduler_{args.resume}.pt")

            amp.load_state_dict(
                torch.load(load_buffer_from_oss(_amp_state_dict)))
            optimizer.load_state_dict(
                torch.load(load_buffer_from_oss(_optimizer_state_dict)))
            scheduler.load_state_dict(
                torch.load(load_buffer_from_oss(_scheduler_state_dict)))

            logger.info(f"Loaded resumed state dict of step {args.resume}")

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Instantaneous batch size per GPU = %d",
                    args.train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.train_batch_size * args.gradient_accumulation_steps *
            (dist.get_world_size() if args.local_rank != -1 else 1),
        )
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        model.train()
        epc = 0
        # test
        if args.local_rank in [-1, 0]:
            if args.fp16:
                amp_file = os.path.join(args.oss_cache_dir,
                                        f"amp_{global_step}.bin")
                torch_save_to_oss(amp.state_dict(), amp_file)
            optimizer_file = os.path.join(args.oss_cache_dir,
                                          f"optimizer_{global_step}.pt")
            torch_save_to_oss(optimizer.state_dict(), optimizer_file)
            scheduler_file = os.path.join(args.oss_cache_dir,
                                          f"scheduler_{global_step}.pt")
            torch_save_to_oss(scheduler.state_dict(), scheduler_file)

        tr_loss = 0
        for _ in range(int(args.num_train_epochs)):
            logger.info('Epoch ' + str(epc + 1))

            TOTAL_NUM = len(train_features)
            train_start_index = 0
            CHUNK_NUM = 8
            train_chunk = TOTAL_NUM // CHUNK_NUM
            chunk_index = 0

            random.shuffle(train_features)

            save_retry = False
            while train_start_index < TOTAL_NUM:
                train_end_index = min(train_start_index + train_chunk - 1,
                                      TOTAL_NUM - 1)
                chunk_len = train_end_index - train_start_index + 1

                if args.resume is not None and global_step < args.resume:
                    _chunk_steps = int(
                        math.ceil(chunk_len * 1.0 / args.train_batch_size /
                                  (1 if args.local_rank == -1 else
                                   dist.get_world_size())))
                    _chunk_steps = _chunk_steps // args.gradient_accumulation_steps
                    if global_step + _chunk_steps <= args.resume:
                        global_step += _chunk_steps
                        train_start_index = train_end_index + 1
                        continue

                train_features_ = train_features[
                    train_start_index:train_start_index + chunk_len]

                all_input_ids = torch.tensor(
                    [f.input_ids for f in train_features_], dtype=torch.long)
                all_input_masks = torch.tensor(
                    [f.input_masks for f in train_features_], dtype=torch.long)
                all_segment_ids = torch.tensor(
                    [f.segment_ids for f in train_features_], dtype=torch.long)
                all_output_masks = torch.tensor(
                    [f.output_masks for f in train_features_],
                    dtype=torch.float)
                all_num_paragraphs = torch.tensor(
                    [f.num_paragraphs for f in train_features_],
                    dtype=torch.long)
                all_num_steps = torch.tensor(
                    [f.num_steps for f in train_features_], dtype=torch.long)
                train_data = TensorDataset(all_input_ids, all_input_masks,
                                           all_segment_ids, all_output_masks,
                                           all_num_paragraphs, all_num_steps)

                if args.local_rank != -1:
                    train_sampler = torch.utils.data.DistributedSampler(
                        train_data)
                else:
                    train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(train_data,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size,
                                              pin_memory=True,
                                              num_workers=4)

                if args.local_rank != -1:
                    train_dataloader.sampler.set_epoch(epc)

                logger.info('Examples from ' + str(train_start_index) +
                            ' to ' + str(train_end_index))
                for step, batch in enumerate(
                        tqdm(train_dataloader,
                             desc="Iteration",
                             disable=args.local_rank not in [-1, 0])):
                    if args.resume is not None and global_step < args.resume:
                        if (step + 1) % args.gradient_accumulation_steps == 0:
                            global_step += 1
                        continue

                    input_masks = batch[1]
                    batch_max_len = input_masks.sum(dim=2).max().item()

                    num_paragraphs = batch[4]
                    batch_max_para_num = num_paragraphs.max().item()

                    num_steps = batch[5]
                    batch_max_steps = num_steps.max().item()

                    # output_masks_cpu = (batch[3])[:, :batch_max_steps, :batch_max_para_num + 1]

                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_masks, segment_ids, output_masks, _, _ = batch
                    B = input_ids.size(0)

                    input_ids = input_ids[:, :batch_max_para_num, :
                                          batch_max_len]
                    input_masks = input_masks[:, :batch_max_para_num, :
                                              batch_max_len]
                    segment_ids = segment_ids[:, :batch_max_para_num, :
                                              batch_max_len]
                    output_masks = output_masks[:, :batch_max_steps, :
                                                batch_max_para_num +
                                                1]  # 1 for EOE

                    target = torch.zeros(output_masks.size()).fill_(
                        NEGATIVE)  # (B, NUM_STEPS, |P|+1) <- 1 for EOE
                    for i in range(B):
                        output_masks[i, :num_steps[i], -1] = 1.0  # for EOE

                        for j in range(num_steps[i].item() - 1):
                            target[i, j, j].fill_(POSITIVE)

                        target[i, num_steps[i] - 1, -1].fill_(POSITIVE)
                    target = target.to(device)

                    neg_start = batch_max_steps - 1
                    while neg_start < batch_max_para_num:
                        neg_end = min(neg_start + args.neg_chunk - 1,
                                      batch_max_para_num - 1)
                        neg_len = (neg_end - neg_start + 1)

                        input_ids_ = torch.cat(
                            (input_ids[:, :batch_max_steps - 1, :],
                             input_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        input_masks_ = torch.cat(
                            (input_masks[:, :batch_max_steps - 1, :],
                             input_masks[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        segment_ids_ = torch.cat(
                            (segment_ids[:, :batch_max_steps - 1, :],
                             segment_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        output_masks_ = torch.cat(
                            (output_masks[:, :, :batch_max_steps - 1],
                             output_masks[:, :, neg_start:neg_start + neg_len],
                             output_masks[:, :, batch_max_para_num:
                                          batch_max_para_num + 1]),
                            dim=2)
                        target_ = torch.cat(
                            (target[:, :, :batch_max_steps - 1],
                             target[:, :, neg_start:neg_start + neg_len],
                             target[:, :,
                                    batch_max_para_num:batch_max_para_num +
                                    1]),
                            dim=2)

                        if neg_start != batch_max_steps - 1:
                            output_masks_[:, :, :batch_max_steps - 1] = 0.0
                            output_masks_[:, :, -1] = 0.0

                        loss = model(input_ids_, segment_ids_, input_masks_,
                                     output_masks_, target_, batch_max_steps)

                        if n_gpu > 1:
                            loss = loss.mean(
                            )  # mean() to average on multi-gpu.
                        if args.gradient_accumulation_steps > 1:
                            loss = loss / args.gradient_accumulation_steps

                        if args.fp16:
                            with amp.scale_loss(loss,
                                                optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                        tr_loss += loss.item()
                        neg_start = neg_end + 1

                        # del input_ids_
                        # del input_masks_
                        # del segment_ids_
                        # del output_masks_
                        # del target_

                    if (step + 1) % args.gradient_accumulation_steps == 0:

                        if args.fp16:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), 1.0)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(), 1.0)

                        optimizer.step()
                        scheduler.step()
                        # optimizer.zero_grad()
                        model.zero_grad()
                        global_step += 1

                        if global_step % 50 == 0:
                            _cur_steps = global_step if args.resume is None else global_step - args.resume
                            logger.info(
                                f"Training loss: {tr_loss / _cur_steps}\t"
                                f"Learning rate: {scheduler.get_lr()[0]}\t"
                                f"Global step: {global_step}")

                        if global_step % args.save_steps == 0:
                            if args.local_rank in [-1, 0]:
                                model_to_save = model.module if hasattr(
                                    model, 'module') else model
                                output_model_file = os.path.join(
                                    args.oss_cache_dir,
                                    f"pytorch_model_{global_step}.bin")
                                torch_save_to_oss(model_to_save.state_dict(),
                                                  output_model_file)

                            _suffix = "" if args.local_rank == -1 else f"_{args.local_rank}"
                            if args.fp16:
                                amp_file = os.path.join(
                                    args.oss_cache_dir,
                                    f"amp_{global_step}{_suffix}.bin")
                                torch_save_to_oss(amp.state_dict(), amp_file)
                            optimizer_file = os.path.join(
                                args.oss_cache_dir,
                                f"optimizer_{global_step}{_suffix}.pt")
                            torch_save_to_oss(optimizer.state_dict(),
                                              optimizer_file)
                            scheduler_file = os.path.join(
                                args.oss_cache_dir,
                                f"scheduler_{global_step}{_suffix}.pt")
                            torch_save_to_oss(scheduler.state_dict(),
                                              scheduler_file)

                            logger.info(
                                f"checkpoint of step {global_step} is saved to oss."
                            )

                    # del input_ids
                    # del input_masks
                    # del segment_ids
                    # del output_masks
                    # del target
                    # del batch

                chunk_index += 1
                train_start_index = train_end_index + 1

                # Save the model at the half of the epoch
                if (chunk_index == CHUNK_NUM // 2
                        or save_retry) and args.local_rank in [-1, 0]:
                    status = save(model, args.output_dir, str(epc + 0.5))
                    save_retry = (not status)

                del train_features_
                del all_input_ids
                del all_input_masks
                del all_segment_ids
                del all_output_masks
                del all_num_paragraphs
                del all_num_steps
                del train_data
                del train_sampler
                del train_dataloader
                gc.collect()

            # Save the model at the end of the epoch
            if args.local_rank in [-1, 0]:
                save(model, args.output_dir, str(epc + 1))
                # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                # output_model_file = os.path.join(args.oss_cache_dir, "pytorch_model_" + str(epc + 1) + ".bin")
                # torch_save_to_oss(model_to_save.state_dict(), output_model_file)

            epc += 1

    if do_train:
        return

    ##############################
    # Evaluation                 #
    ##############################
    assert args.model_suffix is not None

    if graph_retriever_config.db_save_path is not None:
        import sys
        sys.path.append('../')
        from pipeline.tfidf_retriever import TfidfRetriever
        tfidf_retriever = TfidfRetriever(graph_retriever_config.db_save_path,
                                         None)
    else:
        tfidf_retriever = None

    if args.oss_cache_dir is not None:
        file_name = 'pytorch_model_' + args.model_suffix + '.bin'
        model_state_dict = torch.load(
            load_buffer_from_oss(os.path.join(args.oss_cache_dir, file_name)))
    else:
        model_state_dict = load(args.output_dir, args.model_suffix)

    model = RobertaForGraphRetriever.from_pretrained(
        args.bert_model,
        state_dict=model_state_dict,
        graph_retriever_config=graph_retriever_config)
    model.to(device)

    model.eval()

    if args.pred_file is not None:
        pred_output = []

    eval_examples = processor.get_dev_examples(graph_retriever_config)

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    TOTAL_NUM = len(eval_examples)
    eval_start_index = 0

    while eval_start_index < TOTAL_NUM:
        eval_end_index = min(
            eval_start_index + graph_retriever_config.eval_chunk - 1,
            TOTAL_NUM - 1)
        chunk_len = eval_end_index - eval_start_index + 1

        eval_features = convert_examples_to_features(
            eval_examples[eval_start_index:eval_start_index + chunk_len],
            args.max_seq_length, args.max_para_num, graph_retriever_config,
            tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_masks = torch.tensor([f.input_masks for f in eval_features],
                                       dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_output_masks = torch.tensor(
            [f.output_masks for f in eval_features], dtype=torch.float)
        all_num_paragraphs = torch.tensor(
            [f.num_paragraphs for f in eval_features], dtype=torch.long)
        all_num_steps = torch.tensor([f.num_steps for f in eval_features],
                                     dtype=torch.long)
        all_ex_indices = torch.tensor([f.ex_index for f in eval_features],
                                      dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_masks,
                                  all_segment_ids, all_output_masks,
                                  all_num_paragraphs, all_num_steps,
                                  all_ex_indices)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        for input_ids, input_masks, segment_ids, output_masks, num_paragraphs, num_steps, ex_indices in tqdm(
                eval_dataloader, desc="Evaluating"):
            batch_max_len = input_masks.sum(dim=2).max().item()
            batch_max_para_num = num_paragraphs.max().item()

            batch_max_steps = num_steps.max().item()

            input_ids = input_ids[:, :batch_max_para_num, :batch_max_len]
            input_masks = input_masks[:, :batch_max_para_num, :batch_max_len]
            segment_ids = segment_ids[:, :batch_max_para_num, :batch_max_len]
            output_masks = output_masks[:, :batch_max_para_num +
                                        2, :batch_max_para_num + 1]
            output_masks[:, 1:, -1] = 1.0  # Ignore EOE in the first step

            input_ids = input_ids.to(device)
            input_masks = input_masks.to(device)
            segment_ids = segment_ids.to(device)
            output_masks = output_masks.to(device)

            examples = [
                eval_examples[eval_start_index + ex_indices[i].item()]
                for i in range(input_ids.size(0))
            ]

            with torch.no_grad():
                pred, prob, topk_pred, topk_prob = model.beam_search(
                    input_ids,
                    segment_ids,
                    input_masks,
                    examples=examples,
                    tokenizer=tokenizer,
                    retriever=tfidf_retriever,
                    split_chunk=args.split_chunk)

            for i in range(len(pred)):
                e = examples[i]
                titles = [e.title_order[p] for p in pred[i]]

                # Output predictions to a file
                if args.pred_file is not None:
                    pred_output.append({})
                    pred_output[-1]['q_id'] = e.guid

                    pred_output[-1]['titles'] = titles
                    pred_output[-1]['probs'] = []
                    for prob_ in prob[i]:
                        entry = {'EOE': prob_[-1]}
                        for j in range(len(e.title_order)):
                            entry[e.title_order[j]] = prob_[j]
                        pred_output[-1]['probs'].append(entry)

                    topk_titles = [[e.title_order[p] for p in topk_pred[i][j]]
                                   for j in range(len(topk_pred[i]))]
                    pred_output[-1]['topk_titles'] = topk_titles

                    topk_probs = []
                    for k in range(len(topk_prob[i])):
                        topk_probs.append([])
                        for prob_ in topk_prob[i][k]:
                            entry = {'EOE': prob_[-1]}
                            for j in range(len(e.title_order)):
                                entry[e.title_order[j]] = prob_[j]
                            topk_probs[-1].append(entry)
                    pred_output[-1]['topk_probs'] = topk_probs

                    # Output the selected paragraphs
                    context = {}
                    for ts in topk_titles:
                        for t in ts:
                            context[t] = e.all_paras[t]
                    pred_output[-1]['context'] = context

        eval_start_index = eval_end_index + 1

        del eval_features
        del all_input_ids
        del all_input_masks
        del all_segment_ids
        del all_output_masks
        del all_num_paragraphs
        del all_num_steps
        del all_ex_indices
        del eval_data

    if args.pred_file is not None:
        json.dump(pred_output, open(args.pred_file, 'w'))
Beispiel #6
0
                                             batch_size=50,
                                             shuffle=False,
                                             num_workers=opt.num_workers)

    # Build the classifier; it is constructed from the number of labels in
    # label2idx and the BERT path given in the options.
    model = BertHierAttNet(len(label2idx), opt.bert_path)
    if opt.gpu:
        # Wrap in DataParallel and move to CUDA. From here on `model` is the
        # wrapper, so the state_dict saved and reloaded below is the
        # wrapper's state_dict.
        model = nn.DataParallel(model)
        model.cuda()
    optimizer = AdamW(model.parameters(), lr=opt.lr)

    # Best (f1_micro + f1_macro) seen so far. The comparison below is strict,
    # so an epoch must score above 0 before any checkpoint is written.
    max_f1 = 0
    for epoch in range(opt.nepoch):
        # One call to train() and one call to test() per epoch; both helpers
        # are defined outside this fragment.
        loss = train(model, trainloader, optimizer, opt)
        acc, f1_micro, f1_macro = test(model, testloader, opt)
        print("Epoch:%d loss:%f Acc:%.3f F1_micro:%.3f F1_macro:%.3f" %
              (epoch, loss, acc, f1_micro, f1_macro))
        if f1_micro + f1_macro > max_f1:
            # New best: overwrite the single checkpoint at opt.model_path
            # with the epoch index plus model and optimizer state.
            max_f1 = f1_micro + f1_macro
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, opt.model_path)

    # Reload the best checkpoint and score it once more.
    # NOTE(review): if no epoch ever beat max_f1 == 0, nothing was saved above
    # and this load depends on a pre-existing file at opt.model_path.
    # NOTE(review): no map_location is passed, so tensors are restored to the
    # device they were saved from.
    checkpoint = torch.load(opt.model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    # wr=True is forwarded to test(); presumably it makes test() write its
    # predictions out -- confirm against test()'s definition.
    acc, f1_micro, f1_macro = test(model, testloader, opt, wr=True)
    print("Epoch:%d Acc:%.3f F1_micro:%.3f F1_macro:%.3f" %
          (checkpoint['epoch'], acc, f1_micro, f1_macro))
            print("==========================================")

        # Checkpoint management. Only active when ckpt_every is positive and
        # the score history holds more than ckpt_lookback entries.
        if ckpt_every > 0 and len(total_score_history) > ckpt_lookback:
            # Smoothed score: mean of the most recent ckpt_lookback entries.
            current_score = np.mean(total_score_history[-ckpt_lookback:])

            # Time-based check: runs once more than ckpt_every seconds of
            # wall-clock time have passed since time_ckpt was last reset.
            if time.time() - time_ckpt > ckpt_every:
                # Revert when a best score exists and the current score is
                # below both 1.2x and 0.8x of it. Taking the min of the two
                # picks the lower threshold whichever sign the best score has
                # (0.8x when it is positive, 1.2x when it is negative).
                revert_ckpt = best_ckpt_score is not None and current_score < min(
                    1.2 * best_ckpt_score,
                    0.8 * best_ckpt_score)  # Could be negative or positive
                print("================================== CKPT TIME, " +
                      str(datetime.now()) +
                      " =================================")
                print("Previous best:", best_ckpt_score)
                print("Current Score:", current_score)
                print("[CKPT] Am I reverting?",
                      ("yes" if revert_ckpt else "no! BEST CKPT"))
                if revert_ckpt:
                    # Roll model and optimizer back to the last saved best.
                    summarizer.model.load_state_dict(torch.load(ckpt_file))
                    optimizer.load_state_dict(torch.load(ckpt_optimizer_file))
                # Restart the timer whether or not a revert happened.
                time_ckpt = time.time()
                print(
                    "=============================================================================="
                )

            # Independently of the timer: whenever the smoothed score is a new
            # best (or no best exists yet), record it and overwrite the model
            # and optimizer checkpoint files.
            if best_ckpt_score is None or current_score > best_ckpt_score:
                print("[CKPT] Saved new best at: %.3f %s" %
                      (current_score, "[" + str(datetime.now()) + "]"))
                best_ckpt_score = current_score
                torch.save(summarizer.model.state_dict(), ckpt_file)
                torch.save(optimizer.state_dict(), ckpt_optimizer_file)
Beispiel #8
0
 # Body of one training step; `loss`, `logits`, `labels`, `step` and `epoch`
 # are produced by code outside this fragment.
 # With more than one visible GPU the loss is reduced to a scalar via mean()
 # (presumably because the model is wrapped in DataParallel -- confirm).
 if torch.cuda.device_count()>1:
     loss=loss.mean()
 loss.backward()
 torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
 optimizer.step()
 scheduler.step()
 model.zero_grad()
 # Every 10 steps: report loss and accuracy of the current batch, taking the
 # index of the largest logit along dim 1 as the prediction.
 if step % 10 == 0:
     pred=torch.max(logits,1)[1]
     correct=pred.eq(labels).cpu().numpy()
     print('step: %d  loss: %.2f   accuracy: %.2f%%' % (step,loss,(np.sum(correct) / len(correct) * 100)))
 # Every 100 steps (skipping step 0): write a checkpoint and set up an
 # evaluation pass. `correct` is defined here because every multiple of 100
 # also took the `step % 10 == 0` branch above.
 if step and step % 100 == 0:
     state = {
         'epoch': epoch,
         'state_dict': model.state_dict(),
         'optimizer': optimizer.state_dict(),
         'step': step,
         'accuracy':(np.sum(correct) / len(correct) * 100),
         # NOTE(review): this stores the loss tensor itself, not a float.
         'loss':loss
     }
     os.makedirs('trained_model',exist_ok=True)
     # File name encodes the current date (YYYYMMDD), epoch and step.
     filename='./trained_model/model_'+time.strftime("%Y%m%d")+'_epoch_'+str(epoch)+'_step_'+str(step) +'.mdl'
     torch.save(state,filename)
     # Immediately read the file back and load its weights into the model.
     chpt=torch.load(filename)
     model.load_state_dict(chpt['state_dict'])
     # Switch to eval mode; no switch back to train mode is visible in this
     # fragment -- verify it happens after the evaluation that follows.
     model.eval()
     # Accumulators for the evaluation pass (its loop is outside this view).
     eval_accuracy,eval_loss=0,0
     true_labels=[]
     pred_labels=[]
     pred_probs=[]
     admids=[]
Beispiel #9
0
def main():
    """Run multi-task fine-tuning of ViLBERT (or the single-stream baseline).

    The function parses the command line, reads the task definitions from
    ``vilbert_tasks.yml``, builds the datasets, model, optimizer and
    learning-rate schedulers, optionally resumes from ``--resume_file`` and
    then runs the training loop over all tasks listed in ``--tasks``
    (ids separated by ``-``).

    After every epoch the process for which ``default_gpu`` is true writes
    ``pytorch_model_<epoch>.bin`` and ``pytorch_ckpt_latest.tar`` below
    ``<output_dir>/<task names>_<config name><-save_name>``.

    Raises:
        ValueError: if ``--optim`` is neither ``AdamW`` nor ``RAdam``.
        ImportError: if distributed training is requested but apex is missing.
    """

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` treats every non-empty string (including "False") as
        True.  Here the usual negative spellings and the empty string map to
        False; every other string stays True, as it was before.
        """
        return value.strip().lower() not in ("", "0", "false", "f", "no", "n", "off")

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        help="The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        default="config/bert_base_6layer_6conect.json",
        type=str,
        help="The config file which specified the model details.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=20,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--train_iter_multiplier",
        default=1.0,
        type=float,
        help="multiplier for the multi-task training.",
    )
    parser.add_argument(
        "--train_iter_gap",
        default=4,
        type=int,
        help="forward every n iteration is the validation score is not improving over the last 3 epoch, -1 means will stop",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=_str2bool,
        help="Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=16,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--save_name", default="", type=str, help="save name for training."
    )
    parser.add_argument(
        "--in_memory",
        default=False,
        type=_str2bool,
        help="whether use chunck for parallel training.",
    )
    parser.add_argument(
        "--optim", default="AdamW", type=str, help="what to use for the optimization."
    )
    parser.add_argument(
        "--tasks", default="", type=str, help="1-2-3... training task separate by -"
    )
    parser.add_argument(
        "--freeze",
        default=-1,
        type=int,
        help="till which layer of textual stream of vilbert need to fixed.",
    )
    parser.add_argument(
        "--vision_scratch",
        action="store_true",
        help="whether pre-trained the image or not.",
    )
    parser.add_argument(
        "--evaluation_interval", default=1, type=int, help="evaluate very n epoch."
    )
    parser.add_argument(
        "--lr_scheduler",
        default="mannul",
        type=str,
        help="whether use learning rate scheduler.",
    )
    parser.add_argument(
        "--baseline", action="store_true", help="whether use single stream baseline."
    )
    parser.add_argument(
        "--resume_file", default="", type=str, help="Resume from checkpoint"
    )
    parser.add_argument(
        "--dynamic_attention",
        action="store_true",
        help="whether use dynamic attention.",
    )
    parser.add_argument(
        "--clean_train_sets",
        default=True,
        type=_str2bool,
        help="whether clean train sets for multitask data.",
    )
    parser.add_argument(
        "--visual_target",
        default=0,
        type=int,
        help="which target to use for visual branch. \
        0: soft label, \
        1: regress the feature, \
        2: NCE loss.",
    )
    parser.add_argument(
        "--task_specific_tokens",
        action="store_true",
        help="whether to use task specific tokens for the multi-task learning.",
    )

    args = parser.parse_args()
    with open("vilbert_tasks.yml", "r") as f:
        task_cfg = edict(yaml.safe_load(f))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # The model classes are imported lazily because the baseline and the
    # two-stream model come from different modules.
    if args.baseline:
        from pytorch_transformers.modeling_bert import BertConfig
        from vilbert.basebert import BaseBertForVLTasks
    else:
        from vilbert.vilbert import BertConfig
        from vilbert.vilbert import VILBertForVLTasks

    # Collect the name and learning rate of every requested task.
    requested_tasks = args.tasks.split("-")
    task_names = []
    task_lr = []
    for task_id in requested_tasks:
        task = "TASK" + task_id
        name = task_cfg[task]["name"]
        task_names.append(name)
        task_lr.append(task_cfg[task]["lr"])

    # The smallest task learning rate is the optimizer's base rate; each
    # task's loss is scaled by the ratio of its own rate to that base rate.
    base_lr = min(task_lr)
    loss_scale = {}
    for i, task_id in enumerate(requested_tasks):
        task = "TASK" + task_id
        loss_scale[task] = task_lr[i] / base_lr

    if args.save_name:
        prefix = "-" + args.save_name
    else:
        prefix = ""
    timeStamp = (
        "-".join(task_names)
        + "_"
        + args.config_file.split("/")[1].split(".")[0]
        + prefix
    )
    savePath = os.path.join(args.output_dir, timeStamp)

    # Use a context manager so the file handle is closed deterministically.
    with open("config/" + args.bert_model + "_weight_name.json", "r") as f:
        bert_weight_name = json.load(f)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    # Only one process (rank 0, or the single non-distributed process) writes
    # files and prints progress.
    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True

    if default_gpu:
        os.makedirs(savePath, exist_ok=True)

    config = BertConfig.from_json_file(args.config_file)
    if default_gpu:
        # save all the hidden parameters.
        with open(os.path.join(savePath, "command.txt"), "w") as f:
            print(args, file=f)  # Python 3.x
            print("\n", file=f)
            print(config, file=f)

    task_batch_size, task_num_iters, task_ids, task_datasets_train, task_datasets_val, task_dataloader_train, task_dataloader_val = LoadDatasets(
        args, task_cfg, requested_tasks
    )

    logdir = os.path.join(savePath, "logs")
    tbLogger = utils.tbLogger(
        logdir,
        savePath,
        task_names,
        task_ids,
        task_num_iters,
        args.gradient_accumulation_steps,
    )

    if args.visual_target == 0:
        config.v_target_size = 1601
        config.visual_target = args.visual_target
    else:
        config.v_target_size = 2048
        config.visual_target = args.visual_target

    if args.task_specific_tokens:
        config.task_specific_tokens = True

    # exist_ok avoids a race when several ranks reach this line together.
    os.makedirs(args.output_dir, exist_ok=True)

    task_ave_iter = {}
    task_stop_controller = {}
    for task_id, num_iter in task_num_iters.items():
        # Look up the epoch budget of *this* task.  The previous code indexed
        # task_cfg with the stale ``task`` variable left over from the loop
        # above, so every task silently used the last task's ``num_epoch``.
        task_ave_iter[task_id] = int(
            task_cfg[task_id]["num_epoch"]
            * num_iter
            * args.train_iter_multiplier
            / args.num_train_epochs
        )
        task_stop_controller[task_id] = utils.MultiTaskStopOnPlateau(
            mode="max",
            patience=1,
            continue_threshold=0.005,
            cooldown=1,
            threshold=0.001,
        )

    # Despite its name this is the largest per-epoch iteration count of any
    # task (last element of the sorted list), not the median.
    task_ave_iter_list = sorted(task_ave_iter.values())
    median_num_iter = task_ave_iter_list[-1]
    num_train_optimization_steps = (
        median_num_iter * args.num_train_epochs // args.gradient_accumulation_steps
    )
    num_labels = max(dataset.num_labels for dataset in task_datasets_train.values())

    if args.dynamic_attention:
        config.dynamic_attention = True
    if "roberta" in args.bert_model:
        config.model = "roberta"

    if args.baseline:
        model = BaseBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )
    else:
        model = VILBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )

    task_losses = LoadLosses(args, task_cfg, requested_tasks)

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    if args.freeze != -1:
        # Freeze the embeddings and every encoder layer up to --freeze.
        bert_weight_name_filtered = []
        for name in bert_weight_name:
            if "embeddings" in name:
                bert_weight_name_filtered.append(name)
            elif "encoder" in name:
                layer_num = name.split(".")[2]
                if int(layer_num) <= args.freeze:
                    bert_weight_name_filtered.append(name)

        for key, value in model.named_parameters():
            # key[12:] drops the first 12 characters of the parameter name
            # before comparing against the BERT weight names.
            # NOTE(review): presumably a fixed module prefix — confirm.
            if key[12:] in bert_weight_name_filtered:
                value.requires_grad = False

        if default_gpu:
            print("filtered weight")
            print(bert_weight_name_filtered)

    # One parameter group per trainable tensor so that each can carry its own
    # learning rate and weight decay.
    optimizer_grouped_parameters = []
    for key, value in model.named_parameters():
        if value.requires_grad:
            if "vil_" in key:
                lr = 1e-4
            else:
                if args.vision_scratch:
                    if key[12:] in bert_weight_name:
                        lr = base_lr
                    else:
                        lr = 1e-4
                else:
                    lr = base_lr
            # Biases and LayerNorm parameters are excluded from weight decay.
            if any(nd in key for nd in no_decay):
                weight_decay = 0.0
            else:
                weight_decay = 0.01
            optimizer_grouped_parameters.append(
                {"params": [value], "lr": lr, "weight_decay": weight_decay}
            )

    if default_gpu:
        print(len(list(model.named_parameters())), len(optimizer_grouped_parameters))

    if args.optim == "AdamW":
        optimizer = AdamW(optimizer_grouped_parameters, lr=base_lr, correct_bias=False)
    elif args.optim == "RAdam":
        optimizer = RAdam(optimizer_grouped_parameters, lr=base_lr)
    else:
        # Fail early instead of hitting a NameError on ``optimizer`` later.
        raise ValueError(
            "Unsupported --optim %r; expected 'AdamW' or 'RAdam'." % args.optim
        )

    warmpu_steps = args.warmup_proportion * num_train_optimization_steps

    if args.lr_scheduler == "warmup_linear":
        warmup_scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=warmpu_steps, t_total=num_train_optimization_steps
        )
    else:
        warmup_scheduler = WarmupConstantSchedule(optimizer, warmup_steps=warmpu_steps)

    # Epochs at which the "mannul" schedule multiplies the rate by 0.2.
    lr_reduce_list = np.array([5, 7])
    if args.lr_scheduler == "automatic":
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="max", factor=0.2, patience=1, cooldown=1, threshold=0.001
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = CosineAnnealingLR(
            optimizer, T_max=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "cosine_warm":
        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer, T_0=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "mannul":

        def lr_lambda_fun(epoch):
            return pow(0.2, np.sum(lr_reduce_list <= epoch))

        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_fun)

    startIterID = 0
    global_step = 0
    start_epoch = 0

    if args.resume_file != "" and os.path.exists(args.resume_file):
        # NOTE: torch.load unpickles arbitrary objects (the checkpoint holds
        # the stop controllers and the logger); only load trusted files.
        checkpoint = torch.load(args.resume_file, map_location="cpu")
        # Strip the "module." prefix added by (Distributed)DataParallel.
        new_dict = {}
        for attr in checkpoint["model_state_dict"]:
            if attr.startswith("module."):
                new_dict[attr.replace("module.", "", 1)] = checkpoint[
                    "model_state_dict"
                ][attr]
            else:
                new_dict[attr] = checkpoint["model_state_dict"][attr]
        model.load_state_dict(new_dict)
        warmup_scheduler.load_state_dict(checkpoint["warmup_scheduler_state_dict"])
        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        global_step = checkpoint["global_step"]
        start_epoch = int(checkpoint["epoch_id"]) + 1
        task_stop_controller = checkpoint["task_stop_controller"]
        tbLogger = checkpoint["tb_logger"]
        del checkpoint

    model.to(device)

    # Move restored optimizer state to the same device as the model.  Using
    # ``device`` instead of an unconditional .cuda() keeps resuming usable
    # with --no_cuda or on machines without a GPU.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model, delay_allreduce=True)

    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if default_gpu:
        print("***** Running training *****")
        print("  Num Iters: ", task_num_iters)
        print("  Batch size: ", task_batch_size)
        print("  Num steps: %d" % num_train_optimization_steps)

    task_iter_train = {name: None for name in task_ids}
    task_count = {name: 0 for name in task_ids}
    for epochId in tqdm(range(start_epoch, args.num_train_epochs), desc="Epoch"):
        model.train()
        for step in range(median_num_iter):
            iterId = startIterID + step + (epochId * median_num_iter)
            first_task = True
            for task_id in task_ids:
                # A task that has been put "in stop" is still trained on
                # every --train_iter_gap-th iteration.
                is_forward = False
                if (not task_stop_controller[task_id].in_stop) or (
                    iterId % args.train_iter_gap == 0
                ):
                    is_forward = True

                if is_forward:
                    loss, score = ForwardModelsTrain(
                        args,
                        task_cfg,
                        device,
                        task_id,
                        task_count,
                        task_iter_train,
                        task_dataloader_train,
                        model,
                        task_losses,
                    )

                    loss = loss * loss_scale[task_id]
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        if args.fp16:
                            # NOTE(review): the parser above defines no
                            # --learning_rate option, so this branch raises
                            # AttributeError as written; left unchanged
                            # because the intended rate is not visible here.
                            lr_this_step = args.learning_rate * warmup_linear(
                                global_step / num_train_optimization_steps,
                                args.warmup_proportion,
                            )
                            for param_group in optimizer.param_groups:
                                param_group["lr"] = lr_this_step

                        # The warmup scheduler and the global step advance
                        # once per iteration, not once per task.
                        if first_task and (
                            global_step < warmpu_steps
                            or args.lr_scheduler == "warmup_linear"
                        ):
                            warmup_scheduler.step()

                        optimizer.step()
                        model.zero_grad()
                        if first_task:
                            global_step += 1
                            first_task = False

                        if default_gpu:
                            tbLogger.step_train(
                                epochId,
                                iterId,
                                float(loss),
                                float(score),
                                optimizer.param_groups[0]["lr"],
                                task_id,
                                "train",
                            )

            if "cosine" in args.lr_scheduler and global_step > warmpu_steps:
                lr_scheduler.step()

            if (
                step % (20 * args.gradient_accumulation_steps) == 0
                and step != 0
                and default_gpu
            ):
                tbLogger.showLossTrain()

            # decided whether to evaluate on each tasks.
            for task_id in task_ids:
                if (iterId != 0 and iterId % task_num_iters[task_id] == 0) or (
                    epochId == args.num_train_epochs - 1 and step == median_num_iter - 1
                ):
                    evaluate(
                        args,
                        task_dataloader_val,
                        task_stop_controller,
                        task_cfg,
                        device,
                        task_id,
                        model,
                        task_losses,
                        epochId,
                        default_gpu,
                        tbLogger,
                    )

        if args.lr_scheduler == "automatic":
            # NOTE(review): ``val_scores`` is never assigned in this function
            # (the result of evaluate() is discarded), so this mode raises
            # NameError unless a module-level ``val_scores`` exists — verify.
            lr_scheduler.step(sum(val_scores.values()))
            logger.info("best average score is %3f" % lr_scheduler.best)
        elif args.lr_scheduler == "mannul":
            lr_scheduler.step()

        if epochId in lr_reduce_list:
            for task_id in task_ids:
                # reset the task_stop_controller once the lr drop
                task_stop_controller[task_id]._reset()

        if default_gpu:
            # Save a trained model
            logger.info("** ** * Saving fine - tuned model ** ** * ")
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Only save the model it-self
            output_model_file = os.path.join(
                savePath, "pytorch_model_" + str(epochId) + ".bin"
            )
            output_checkpoint = os.path.join(savePath, "pytorch_ckpt_latest.tar")
            torch.save(model_to_save.state_dict(), output_model_file)
            torch.save(
                {
                    "model_state_dict": model_to_save.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "warmup_scheduler_state_dict": warmup_scheduler.state_dict(),
                    # 'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    "global_step": global_step,
                    "epoch_id": epochId,
                    "task_stop_controller": task_stop_controller,
                    "tb_logger": tbLogger,
                },
                output_checkpoint,
            )
    tbLogger.txt_close()
def model_train_validate_test(train_df,
                              dev_df,
                              test_df,
                              target_dir,
                              max_seq_len=50,
                              epochs=3,
                              batch_size=32,
                              lr=2e-05,
                              patience=1,
                              max_grad_norm=10.0,
                              if_save_model=True,
                              checkpoint=None):
    """
    Train a BERT classifier, validate it after every epoch and, whenever the
    validation accuracy does not drop below the best seen so far, (optionally)
    save a checkpoint and write the test-set predictions to
    ``<target_dir>/test_prediction.csv``.

    Training stops early once the validation accuracy has been below the best
    score for ``patience`` consecutive epochs.

    Parameters
    ----------
    train_df : pandas dataframe of train set.
    dev_df : pandas dataframe of dev set.
    test_df : pandas dataframe of test set.
    target_dir : the path where you want to save model.
    max_seq_len: the max truncated length.
    epochs : the default is 3.
    batch_size : the default is 32.
    lr : learning rate, the default is 2e-05.
    patience : the default is 1.
    max_grad_norm : the default is 10.0.
    if_save_model: if save the trained model to the target dir.
    checkpoint : path of a checkpoint to resume from, the default is None.

    """

    bertmodel = BertModel(requires_grad=True)
    tokenizer = bertmodel.tokenizer

    print(20 * "=", " Preparing for training ", 20 * "=")
    # Path to save the model, create a folder if not exist.
    os.makedirs(target_dir, exist_ok=True)

    # -------------------- Data loading --------------------------------------#

    print("\t* Loading training data...")
    train_data = DataPrecessForSentence(tokenizer,
                                        train_df,
                                        max_seq_len=max_seq_len)
    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading validation data...")
    dev_data = DataPrecessForSentence(tokenizer,
                                      dev_df,
                                      max_seq_len=max_seq_len)
    dev_loader = DataLoader(dev_data, shuffle=True, batch_size=batch_size)

    print("\t* Loading test data...")
    test_data = DataPrecessForSentence(tokenizer,
                                       test_df,
                                       max_seq_len=max_seq_len)
    # The test loader is not shuffled so the written predictions follow the
    # row order of test_data.
    test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)

    # -------------------- Model definition ------------------- --------------#

    print("\t* Building model...")
    device = torch.device("cuda")
    model = bertmodel.to(device)

    # -------------------- Preparation for training  -------------------------#

    # Biases and LayerNorm parameters are excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

    ## Implement of warm up
    ## total_steps = len(train_loader) * epochs
    ## scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=60, num_training_steps=total_steps)

    # When the monitored value is not improving, the network performance could be improved by reducing the learning rate.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="max",
                                                           factor=0.85,
                                                           patience=0)

    best_score = 0.0
    start_epoch = 1
    # Data for loss curves plot
    epochs_count = []
    train_losses = []
    train_accuracies = []
    valid_losses = []
    valid_accuracies = []
    valid_aucs = []

    # Continuing training from a checkpoint if one was given as argument
    if checkpoint:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint["best_score"]
        print("\t* Training will continue on existing model from epoch {}...".
              format(start_epoch))
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epochs_count = checkpoint["epochs_count"]
        train_losses = checkpoint["train_losses"]
        # Restore the history *lists* that are appended to and saved below.
        # The previous code assigned these to differently named variables, so
        # the accuracy/AUC histories were lost on resume and later
        # checkpoints stored curves that were out of step with epochs_count.
        train_accuracies = checkpoint["train_accuracy"]
        valid_losses = checkpoint["valid_losses"]
        valid_accuracies = checkpoint["valid_accuracy"]
        valid_aucs = checkpoint["valid_auc"]

    # Compute loss and accuracy before starting (or resuming) training.
    _, valid_loss, valid_accuracy, auc, _, = validate(model, dev_loader)
    print(
        "\n* Validation loss before training: {:.4f}, accuracy: {:.4f}%, auc: {:.4f}"
        .format(valid_loss, (valid_accuracy * 100), auc))

    # -------------------- Training epochs -----------------------------------#

    print("\n", 20 * "=", "Training bert model on device: {}".format(device),
          20 * "=")
    patience_counter = 0
    for epoch in range(start_epoch, epochs + 1):
        epochs_count.append(epoch)

        print("* Training epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy = train(model, train_loader,
                                                       optimizer, epoch,
                                                       max_grad_norm)
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)
        print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%".
              format(epoch_time, epoch_loss, (epoch_accuracy * 100)))

        print("* Validation for epoch {}:".format(epoch))
        epoch_time, epoch_loss, epoch_accuracy, epoch_auc, _, = validate(
            model, dev_loader)
        valid_losses.append(epoch_loss)
        valid_accuracies.append(epoch_accuracy)
        valid_aucs.append(epoch_auc)
        print(
            "-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%, auc: {:.4f}\n"
            .format(epoch_time, epoch_loss, (epoch_accuracy * 100), epoch_auc))

        # Update the optimizer's learning rate with the scheduler.
        scheduler.step(epoch_accuracy)
        ## scheduler.step()

        # Early stopping on validation accuracy.
        if epoch_accuracy < best_score:
            patience_counter += 1
        else:
            best_score = epoch_accuracy
            patience_counter = 0
            if if_save_model:
                torch.save(
                    {
                        "epoch": epoch,
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "best_score": best_score,
                        "epochs_count": epochs_count,
                        "train_losses": train_losses,
                        "train_accuracy": train_accuracies,
                        "valid_losses": valid_losses,
                        "valid_accuracy": valid_accuracies,
                        "valid_auc": valid_aucs
                    }, os.path.join(target_dir, "best.pth.tar"))
                print("save model succesfully!\n")

            # run model on test set and save the prediction result to csv
            print("* Test for epoch {}:".format(epoch))
            _, _, test_accuracy, _, all_prob = validate(model, test_loader)
            # validate() is the same helper used for the dev set above, whose
            # accuracy is scaled by 100 before being printed with a "%" sign;
            # do the same here so the reported number is a real percentage.
            print("Test accuracy: {:.4f}%\n".format(test_accuracy * 100))
            test_prediction = pd.DataFrame({'prob_1': all_prob})
            test_prediction['prob_0'] = 1 - test_prediction['prob_1']
            # Predict class 0 only when its probability is strictly larger.
            test_prediction['prediction'] = test_prediction.apply(
                lambda x: 0 if (x['prob_0'] > x['prob_1']) else 1, axis=1)
            test_prediction = test_prediction[[
                'prob_0', 'prob_1', 'prediction'
            ]]
            test_prediction.to_csv(os.path.join(target_dir,
                                                "test_prediction.csv"),
                                   index=False)

        if patience_counter >= patience:
            print("-> Early stopping: patience limit reached, stopping...")
            break
Beispiel #11
0
def main():
    """Train the text-only classifier, or evaluate a saved one.

    Reads the train/validation TSV files from ``args.data_dir``, wraps them
    in ``ECTextDataset`` loaders and builds a ``TextOnly`` model under
    ``nn.DataParallel``.  With ``args.test`` set, a fixed best checkpoint is
    loaded, the validation set is scored and the predictions are written to
    ``valid_0_preds.tsv``.  Otherwise the model is trained for ``args.epoch``
    epochs; the latest checkpoint is written after every epoch and a "best"
    checkpoint whenever the validation macro F1 improves.
    """
    setseed(args.seed)
    print('Loading data...')
    train_df = pd.read_csv(args.data_dir + 'train_0_clean.tsv', sep='\t')
    valid_df = pd.read_csv(args.data_dir + 'valid_0_clean.tsv', sep='\t')

    # Missing descriptions become empty strings in both splits.
    for frame in (train_df, valid_df):
        frame['Description'] = frame['Description'].fillna('')

    wanted_columns = [
        'Title', 'Description', 'Image_id', 'Product_id', 'Prdtypecode'
    ]
    train_set = ECTextDataset(train_df[wanted_columns].values,
                              args.text_ptm_dir)
    valid_set = ECTextDataset(valid_df[wanted_columns].values,
                              args.text_ptm_dir)

    # Only the training loader shuffles; validation keeps the file order.
    shared_loader_opts = {'batch_size': args.batch_size, 'num_workers': 8}
    train_loader = DataLoader(train_set, shuffle=True, **shared_loader_opts)
    valid_loader = DataLoader(valid_set, shuffle=False, **shared_loader_opts)

    net = TextOnly(args.text_ptm_dir, args.text_ft_dim, args.num_classes)
    net = net.to(device)
    model = nn.DataParallel(net, device_ids=[0, 1])

    if args.test:
        # Evaluation-only mode: score the validation split and exit.
        best_ckpt_path = (
            args.result_dir +
            '/best_checkpoint_bertpooled_seed2021_bs128_lr5e-05_ep30_numiter17910_warmup1791.pth.tar'
        )
        ckpt = torch.load(best_ckpt_path)
        model.load_state_dict(ckpt['state_dict'])
        acc, macro_f1, all_preds = validate(valid_loader, model)
        summary = 'Validation Best Results: accuracy: {:.4f}, macro f1: {:.4f}'
        print(summary.format(acc, macro_f1))
        valid_df['preds'] = all_preds
        valid_df.to_csv(args.result_dir + '/valid_0_preds.tsv',
                        index=False,
                        sep='\t')
        return

    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), args.lr)
    # Step counts derive from the number of training rows; the first 10% of
    # the steps are used for warmup.
    num_training_steps = int(train_df.shape[0] / args.batch_size) * args.epoch
    num_warmup_steps = int(num_training_steps * 0.1)
    # scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps = num_warmup_steps, num_training_steps = num_training_steps, num_cycles = 6)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps)

    latest_path = args.result_dir + '/latest_checkpoint.pth.tar'
    best_f1 = 0
    for epoch in tqdm(range(args.epoch)):
        # One pass over the training data, then score the validation split.
        trainForEpoch(train_loader, model, optimizer, scheduler, epoch,
                      args.epoch, criterion, log_aggr=200)
        acc, macro_f1, all_preds = validate(valid_loader, model)

        # Snapshot taken after this epoch; 'epoch' is stored 1-based.
        snapshot = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }

        if not os.path.exists(args.result_dir):
            os.makedirs(args.result_dir)

        torch.save(snapshot, latest_path)

        # Keep a separate copy whenever the macro F1 strictly improves.
        if macro_f1 > best_f1:
            best_f1 = macro_f1
            best_name = (
                '/best_checkpoint_seed{}_bs{}_lr{}_ep{}_numiter{}_warmup{}.pth.tar'
                .format(args.seed, args.batch_size, args.lr, args.epoch,
                        num_training_steps, num_warmup_steps))
            torch.save(snapshot, args.result_dir + best_name)

        report = 'Epoch {} validation: accuracy: {:.4f}, macro f1: {:.4f}, best macro f1: {:.4f}'
        print(report.format(epoch, acc, macro_f1, best_f1))
Beispiel #12
0
class CIMClassifier:
    """Trainer/predictor wrapper around a ``ContextAwareModel`` binary classifier.

    The wrapper owns the model, a class-weighted cross-entropy loss (summed over
    the batch) and an AdamW optimizer, and exposes per-batch training, batched
    prediction and checkpoint save/load.
    """

    def __init__(self, emb_dim=768, hid_size=32, layers=1, weights_mat=None, tr_labs=None,
                 b_size=24, cp_dir='models/checkpoints/cim', lr=0.001, start_epoch=0, patience=3,
                 step=1, gamma=0.75, n_eps=10, cim_type='cim', context='art', cp_name=None):
        """Build the loss, the model (fresh or from a checkpoint) and the optimizer.

        Args:
            emb_dim: input size passed to ``ContextAwareModel``.
            hid_size: hidden size passed to ``ContextAwareModel``.
            layers: passed to ``ContextAwareModel`` as ``bilstm_layers``.
            weights_mat: passed to ``ContextAwareModel`` as ``weights_matrix``.
            tr_labs: training labels. Currently unused (the LR scheduler that
                needed the training-set size is disabled); kept so existing
                callers keep working.
            b_size: batch size, stored as ``self.batch_size``.
            cp_dir: directory that checkpoints are written to / read from.
            lr: learning rate for AdamW.
            start_epoch: if > 0, the model is restored from ``cp_name`` instead
                of being freshly initialised.
            patience: initial value of the patience counters.
            step, gamma, n_eps: currently unused; kept for interface
                compatibility.
            cim_type: ``'cim'`` selects class weights [.20, .80]; anything else
                selects [.25, .75]. Also forwarded to the model as ``cam_type``.
            context: forwarded to ``ContextAwareModel``.
            cp_name: checkpoint file name inside ``cp_dir``; required when
                ``start_epoch > 0``.

        Raises:
            ValueError: if ``start_epoch > 0`` but no ``cp_name`` was given.
        """
        self.start_epoch = start_epoch
        self.cp_dir = cp_dir
        self.device, self.use_cuda = get_torch_device()

        self.emb_dim = emb_dim
        self.hidden_size = hid_size
        self.batch_size = b_size

        # Fixed class weights; could be made to depend on the class balance of
        # the training labels supplied on input.
        class_weights = [.20, .80] if cim_type == 'cim' else [.25, .75]
        self.criterion = CrossEntropyLoss(weight=torch.tensor(class_weights, device=self.device),
                                          reduction='sum')

        if start_epoch > 0:
            # load_model() needs a checkpoint name; the previous code called it
            # without one, so resuming always failed with a TypeError.
            if cp_name is None:
                raise ValueError('start_epoch > 0 requires cp_name (checkpoint file name in cp_dir)')
            self.model = self.load_model(cp_name)
        else:
            self.model = ContextAwareModel(input_size=self.emb_dim, hidden_size=self.hidden_size,
                                           bilstm_layers=layers, weights_matrix=weights_mat,
                                           device=self.device, cam_type=cim_type, context=context)
        self.model = self.model.to(self.device)
        if self.use_cuda:
            self.model.cuda()

        # empty now and set during or after training
        self.train_time = 0
        self.prev_val_f1 = 0
        self.cp_name = cp_name  # depends on split type and current fold
        self.full_patience = patience
        self.current_patience = self.full_patience
        self.test_perf = []
        self.test_perf_string = ''

        # No LR scheduler is configured; train_on_batch only steps the optimizer.
        self.optimizer = AdamW(self.model.parameters(), lr=lr, eps=1e-8)

    def load_model(self, name):
        """Restore the model stored under ``cp_dir/name`` and make it current.

        Args:
            name: checkpoint file name relative to ``self.cp_dir``.

        Returns:
            The restored model (also assigned to ``self.model``).
        """
        cpfp = os.path.join(self.cp_dir, name)
        # NOTE: the checkpoint contains the pickled module object itself (see
        # save_model), so only load checkpoints from trusted sources.
        # NOTE(review): recent PyTorch releases may need weights_only=False to
        # unpickle a full module - confirm against the pinned torch version.
        cp = torch.load(cpfp, map_location=torch.device('cpu'))
        model = cp['model']
        model.load_state_dict(cp['state_dict'])
        self.model = model
        self.model.to(self.device)
        if self.use_cuda:
            self.model.cuda()
        return model

    def train_on_batch(self, batch):
        """Run one optimisation step on ``batch`` and return the loss value.

        Args:
            batch: iterable of tensors; the last one holds the labels and the
                preceding ones are passed to the model as a tuple.

        Returns:
            The summed batch loss as a Python float.
        """
        batch = tuple(t.to(self.device) for t in batch)
        inputs, labels = batch[:-1], batch[-1]

        self.model.zero_grad()
        logits, probs, _ = self.model(inputs)
        loss = self.criterion(logits.view(-1, 2), labels.view(-1))
        loss.backward()

        # Clip the gradient norm to 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()

    def save_model(self, name):
        """Write model, its state dict and the optimizer state to ``cp_dir/name``.

        The target directory is created if it does not exist yet.
        """
        checkpoint = {'model': self.model,
                      'state_dict': self.model.state_dict(),
                      'optimizer': self.optimizer.state_dict()}
        cpfp = os.path.join(self.cp_dir, name)
        cp_parent = os.path.dirname(cpfp)
        if cp_parent:
            os.makedirs(cp_parent, exist_ok=True)
        torch.save(checkpoint, cpfp)

    def predict(self, batches):
        """Predict class indices for every sample in ``batches``.

        The model is put in eval mode for the duration of the call and switched
        back to train mode before returning.

        Args:
            batches: iterable of batches shaped like those given to
                ``train_on_batch`` (labels last). It does not need to support
                ``len()``.

        Returns:
            Tuple ``(y_pred, mean_loss, embeddings, losses)``:
              * ``y_pred``: 1-D numpy array with the argmax class per sample
                (empty when ``batches`` is empty);
              * ``mean_loss``: sum of batch losses divided by the number of
                batches (``0.0`` when ``batches`` is empty);
              * ``embeddings``: one list of per-sample numpy arrays per batch,
                taken from the model's third output;
              * ``losses``: per-batch losses as 0-d numpy arrays.
        """
        self.model.eval()

        prob_chunks = []
        losses = []
        embeddings = []
        sum_loss = 0
        n_batches = 0
        for batch in batches:
            n_batches += 1
            batch = tuple(t.to(self.device) for t in batch)
            inputs, labels = batch[:-1], batch[-1]

            with torch.no_grad():
                logits, probs, sentence_representation = self.model(inputs)
                loss = self.criterion(logits.view(-1, 2), labels.view(-1))

            embeddings.append(list(sentence_representation.detach().cpu().numpy()))

            loss = loss.detach().cpu().numpy()
            losses.append(loss)
            sum_loss += loss.item()

            prob_chunks.append(probs.detach().cpu().numpy())

        self.model.train()

        if n_batches == 0:
            return np.array([], dtype=int), 0.0, embeddings, losses

        # Concatenate once instead of np.append per batch, and flatten to
        # (n_samples, n_classes) explicitly: squeeze() would also drop the
        # sample axis when there is exactly one sample and break argmax(axis=1).
        y_pred = np.concatenate(prob_chunks, axis=0)
        y_pred = y_pred.reshape(-1, y_pred.shape[-1])
        y_pred = np.argmax(y_pred, axis=1)
        return y_pred, sum_loss / n_batches, embeddings, losses
Beispiel #13
0
    def train(self):
        """Fine-tune the ELECTRA sequence classifier and checkpoint the best epoch.

        Reads every setting from ``self.config``: builds config, tokenizer and
        model from the pretrained checkpoint, prepares the train and test
        loaders, then alternates ``self.do_train`` and ``self.do_evaluate`` for
        ``config["epoch"]`` epochs. Whenever the test accuracy exceeds the best
        value seen so far, the config, tokenizer, model, optimizer and
        scheduler are written to ``<model_dir_path>/checkpoint-<global_step>``.

        The model is moved with ``.cuda()``, so a CUDA device is required.
        """
        cfg = self.config
        pretrained_path = "/home/mongjin/KuELECTRA_base"

        # Create the ELECTRA config object.
        electra_config = ElectraConfig.from_pretrained(
            pretrained_path,
            num_labels=cfg["senti_labels"],
            cache_dir=cfg["cache_dir_path"])

        # Create the ELECTRA tokenizer object.
        electra_tokenizer = ElectraTokenizer.from_pretrained(
            pretrained_path,
            do_lower_case=False,
            cache_dir=cfg["cache_dir_path"])

        # Create the ELECTRA model object; label/score embedding sizes are
        # twice the LSTM hidden size.
        lstm_hidden = cfg['lstm_hidden']
        electra_model = ElectraForSequenceClassification.from_pretrained(
            pretrained_path,
            config=electra_config,
            lstm_hidden=lstm_hidden,
            label_emb_size=lstm_hidden * 2,
            score_emb_size=lstm_hidden * 2,
            score_size=cfg['score_labels'],
            num_layer=cfg['lstm_num_layer'],
            bilstm_flag=cfg['bidirectional_flag'],
            cache_dir=cfg["cache_dir_path"],
            from_tf=True)

        electra_model.cuda()

        def build_dataset(path_key):
            # Read the raw data for the given config key and preprocess it.
            raw_datas = preprocessing.read_data(file_path=cfg[path_key],
                                                mode=cfg["mode"])
            return preprocessing.convert_data2dataset(
                datas=raw_datas,
                tokenizer=electra_tokenizer,
                max_length=cfg["max_length"],
                labels=cfg["senti_labels"],
                score_labels=cfg["score_labels"],
                mode=cfg["mode"])

        # Training data: shuffled, batched by the configured batch size.
        train_dataset = build_dataset("train_data_path")
        train_dataloader = DataLoader(train_dataset,
                                      sampler=RandomSampler(train_dataset),
                                      batch_size=cfg["batch_size"])

        # Evaluation data: sequential order, fixed batch size of 100.
        test_dataset = build_dataset("test_data_path")
        test_dataloader = DataLoader(test_dataset,
                                     sampler=SequentialSampler(test_dataset),
                                     batch_size=100)

        # Total number of optimisation steps (in batches).
        t_total = len(train_dataloader) // cfg[
            "gradient_accumulation_steps"] * cfg["epoch"]

        # Optimizer: parameters whose name contains one of these markers get
        # no weight decay.
        no_decay = ['bias', 'LayerNorm.weight']

        def skips_decay(param_name):
            return any(marker in param_name for marker in no_decay)

        decayed_group = {
            'params': [
                param for name, param in electra_model.named_parameters()
                if not skips_decay(name)
            ],
            'lr': 5e-5,
            'weight_decay': cfg['weight_decay'],
        }
        undecayed_group = {
            'params': [
                param for name, param in electra_model.named_parameters()
                if skips_decay(name)
            ],
            'lr': 5e-5,
            'weight_decay': 0.0,
        }
        optimizer = AdamW([decayed_group, undecayed_group])
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=cfg["warmup_steps"],
            num_training_steps=t_total)

        # Restore optimizer and scheduler state from a previous run when both
        # files are present in the model directory.
        optimizer_state_path = os.path.join(cfg["model_dir_path"],
                                            "optimizer.pt")
        scheduler_state_path = os.path.join(cfg["model_dir_path"],
                                            "scheduler.pt")
        if os.path.isfile(optimizer_state_path) and os.path.isfile(
                scheduler_state_path):
            optimizer.load_state_dict(torch.load(optimizer_state_path))
            scheduler.load_state_dict(torch.load(scheduler_state_path))
            print(
                "#######################     Success Load Model     ###########################"
            )

        global_step = 0
        electra_model.zero_grad()
        max_test_accuracy = 0
        for epoch_number in range(1, cfg["epoch"] + 1):
            electra_model.train()

            # Accuracy and average loss on the training data.
            train_accuracy, average_loss, global_step, score_acc = self.do_train(
                electra_model=electra_model,
                optimizer=optimizer,
                scheduler=scheduler,
                train_dataloader=train_dataloader,
                epoch=epoch_number,
                global_step=global_step)

            print("train_accuracy : {}\taverage_loss : {}\n".format(
                round(train_accuracy, 4), round(average_loss, 4)))
            print("train_score_accuracy :", "{:.6f}".format(score_acc))

            electra_model.eval()

            # Accuracy on the evaluation data.
            test_accuracy, score_acc = self.do_evaluate(
                electra_model=electra_model,
                test_dataloader=test_dataloader,
                mode=cfg["mode"])

            print("test_accuracy : {}\n".format(round(test_accuracy, 4)))
            print("test_score_accuracy :", "{:.6f}".format(score_acc))

            # Save everything when this epoch beats the best accuracy so far.
            if test_accuracy > max_test_accuracy:
                max_test_accuracy = test_accuracy

                output_dir = os.path.join(cfg["model_dir_path"],
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)

                for artifact in (electra_config, electra_tokenizer,
                                 electra_model):
                    artifact.save_pretrained(output_dir)
                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))

            print("max_test_accuracy :",
                  "{:.6f}".format(round(max_test_accuracy, 4)))
Beispiel #14
0
def train_model(model,
                train_data,
                val_data,
                epochs,
                device,
                save_best_model=True,
                save_checkpoint=True,
                scheduler=True,
                lr=2e-5,
                model_name='best_model',
                output_dir='',
                save_dict=True,
                train_state_path=None):
    """Run the training phase.

    Args:
        model (torch.nn.Module): The model to train.
        train_data (torch.utils.data.DataLoader): Train dataloader.
        val_data (torch.utils.data.DataLoader): Validation dataloader, or None
            to skip evaluation (no best model is tracked or saved then).
        epochs (int): Number of epochs.
        device (torch.device): Training platform.
        save_best_model (bool, optional): Save the state dict whenever the
            validation accuracy improves. Defaults to True.
        save_checkpoint (bool, optional): Save model, optimizer, scheduler and
            history after every epoch (overwriting the previous checkpoint).
            Defaults to True.
        scheduler (bool, optional): Currently ignored: a linear LR scheduler
            without warmup is always created. Kept for interface
            compatibility. Defaults to True.
        lr (float, optional): Learning rate. Defaults to 2e-5.
        model_name (str, optional): Output name of the model. Defaults to
            'best_model'.
        output_dir (str, optional): Output directory; '' means the current
            working directory. Defaults to ''.
        save_dict (bool, optional): Pickle the training history at the end.
            Defaults to True.
        train_state_path (str, optional): Checkpoint path to restart training
            from. Defaults to None.

    Returns:
        dict: Training history with the keys 'train_loss', 'val_loss',
        'train_acc' and 'val_acc'.
    """

    # init optimizer and lr scheduler (local name differs from the boolean
    # `scheduler` parameter so the argument is no longer shadowed)
    total_steps = len(train_data) * epochs
    optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    # init history dict
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }

    best_accuracy = 0

    if train_state_path is not None:
        model, optimizer, lr_scheduler, history, best_accuracy = load_checkpoint(
            model, optimizer, lr_scheduler, train_state_path)

    # init loss function
    loss_fn = nn.CrossEntropyLoss().to(device)

    # config output directory and checkpoint directory; created whenever
    # something will be written, including when output_dir is '' or when only
    # the checkpoint sub-directory is missing
    model_dir = os.path.join(output_dir, model_name)
    checkpoint_dir = os.path.join(model_dir, 'checkpoint')
    if save_best_model or save_checkpoint or save_dict:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # start training loop
    for epoch in range(epochs):

        print(f'Epoch {epoch + 1}/{epochs}')
        print('-' * 10)

        # training epoch
        train_acc, train_loss = run_train_epoch(model, train_data, loss_fn,
                                                optimizer, device,
                                                lr_scheduler)

        print(f'Train loss {train_loss:.4f} accuracy {train_acc:.4f}')
        history['train_acc'].append(train_acc)
        history['train_loss'].append(train_loss)

        # evaluation; val_acc only exists inside this branch, so the
        # best-model bookkeeping lives here too
        if val_data is not None:
            val_acc, val_loss = run_eval(model, val_data, loss_fn, device)

            print(f'Val loss {val_loss:.4f} accuracy {val_acc:.4f}')
            history['val_acc'].append(val_acc)
            history['val_loss'].append(val_loss)

            # save state dict of model with improved acc on val set
            if val_acc > best_accuracy:
                best_accuracy = val_acc
                if save_best_model:
                    print('Val_acc improvement, saving model...')
                    torch.save(
                        model.state_dict(),
                        os.path.join(model_dir, model_name + '-best.bin'))

        if save_checkpoint:
            print('Saving model, optimizer and scheduler...')
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': lr_scheduler.state_dict(),
                    'history': history,
                    'best_accuracy': best_accuracy
                },
                os.path.join(checkpoint_dir, model_name + '.bin'))

        print()

    if save_dict:
        save_dict_as_pickle(
            history, os.path.join(model_dir, model_name + '_history.pkl'))

    return history