Example #1
class TrainerUniter:
    def __init__(self, config):
        self.preds_list, self.probs_list, self.labels_list, self.loss_list, self.short_loss_list, self.id_list = [], [], [], [], [], []
        self.best_val_metrics, self.train_metrics = defaultdict(int), {}
        self.best_auc = 0
        self.not_improved = 0
        self.best_val_loss = 1000
        self.total_iters = 0
        self.terminate_training = False
        self.model_file = os.path.join(config['model_path'],
                                       config['model_save_name'])
        self.pretrained_model_file = None
        if config['pretrained_model_file'] is not None:
            self.pretrained_model_file = os.path.join(
                config['model_path'], config['pretrained_model_file'])
        self.start_epoch = 1
        self.config = config
        self.device = get_device()

        if not isinstance(self.config['test_loader'], list):
            self.config['test_loader'] = [self.config['test_loader']]

        # Initialize the model, optimizer and loss function
        self.init_training_params()

    def init_training_params(self):
        self.init_model()
        wandb.watch(self.model)
        self.model_saver = ModelSaver(self.model_file)

        self.init_optimizer()
        self.init_scheduler()

        if self.config['loss_func'] == 'bce_logits':
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(
                [self.config['pos_wt']]).to(self.device))
        elif self.config['loss_func'] == 'bce':
            self.criterion = nn.BCELoss()
        else:
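            # Any other value (the rest of the code assumes 'ce') falls back
            # to cross-entropy.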
            self.criterion = nn.CrossEntropyLoss()

    def init_scheduler(self):
        if self.config['scheduler'] == 'step':
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=self.config['lr_decay_step'],
                gamma=self.config['lr_decay_factor'])
        elif self.config['scheduler'] == 'multi_step':
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer,
                milestones=[5, 10, 15, 25, 40],
                gamma=self.config['lr_decay_factor'])
        elif self.config['scheduler'] == 'warmup':
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config['warmup_steps'],
                num_training_steps=len(self.config['train_loader']) *
                self.config['max_epoch'])
        elif self.config['scheduler'] == 'warmup_cosine':
            self.scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config['warmup_steps'],
                num_training_steps=len(self.config['train_loader']) *
                self.config['max_epoch'])

    def init_optimizer(self):
        self.optimizer = get_optimizer(self.model, self.config)

    def init_model(self):
        # pretrained_model_file points to the original pretrained checkpoint;
        # load it and fine-tune. If it is None, load the model file you saved
        # after fine-tuning instead.
        if self.pretrained_model_file:
            checkpoint = torch.load(self.pretrained_model_file)
            LOGGER.info('Using pretrained UNITER base model {}'.format(
                self.pretrained_model_file))
            base_model = UniterForPretraining.from_pretrained(
                self.config['config'],
                state_dict=checkpoint['model_state_dict'],
                img_dim=IMG_DIM,
                img_label_dim=IMG_LABEL_DIM)
            self.model = MemeUniter(
                uniter_model=base_model.uniter,
                hidden_size=base_model.uniter.config.hidden_size +
                self.config["race_gender_hidden_size"],
                n_classes=self.config['n_classes'])
        else:
            self.load_model()

    def load_model(self):
        # Load a fine-tuned checkpoint saved by this trainer
        if self.model_file:
            checkpoint = torch.load(self.model_file)
            LOGGER.info('Using UNITER model {}'.format(self.model_file))
        else:
            checkpoint = {}

        uniter_config = UniterConfig.from_json_file(self.config['config'])
        uniter_model = UniterModel(uniter_config, img_dim=IMG_DIM)

        self.model = MemeUniter(uniter_model=uniter_model,
                                hidden_size=uniter_model.config.hidden_size +
                                self.config["race_gender_hidden_size"],
                                n_classes=self.config['n_classes'])
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])

    def average_gradients(self, steps):
        # Used when grad_accumulation > 1
        for param in self.model.parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = param.grad / steps

    def calculate_loss(self, preds, batch_label, grad_step):
        if self.config['loss_func'] == 'bce':
            preds = torch.sigmoid(preds)
        if self.config['loss_func'] == 'bce_logits':
            preds = preds.squeeze(1)
        preds = preds.to(self.device)
        if self.config['loss_func'] == 'ce':
            target = batch_label.to(self.device)
        else:
            target = batch_label.float().to(self.device)
        loss = self.criterion(preds, target)

        if grad_step and self.iters % self.config['gradient_accumulation'] == 0:
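            # When iters is a multiple of gradient_accumulation: backpropagate
            # this batch, rescale the gradients accumulated since the last
            # update (average_gradients), clip, then step the optimizer and
            # the per-update LR scheduler.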
            loss.backward()
            self.average_gradients(steps=self.config['gradient_accumulation'])
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.config['max_grad_norm'])
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
        elif grad_step:
            loss.backward()

        if self.config['loss_func'] == 'bce':
            probs = preds
            preds = (preds > 0.5).type(torch.FloatTensor)
        elif self.config['loss_func'] == 'ce':
            probs = F.softmax(preds, dim=1)
            preds = torch.argmax(probs, dim=1)
        elif self.config['loss_func'] == 'bce_logits':
            probs = torch.sigmoid(preds)
            preds = (probs > 0.5).type(torch.FloatTensor)

        self.probs_list.append(probs.cpu().detach().numpy())
        self.preds_list.append(preds.cpu().detach().numpy())
        self.labels_list.append(batch_label.cpu().detach().numpy())
        self.loss_list.append(loss.detach().item())
        if grad_step:
            self.short_loss_list.append(loss.detach().item())

    def eval_model(self, test=False, test_idx=0):
        self.model.eval()
        self.preds_list, self.probs_list, self.labels_list, self.loss_list, self.id_list = [], [], [], [], []
        if test:
            batch_loader = self.config['test_loader'][test_idx]
        else:
            batch_loader = self.config['val_loader']
        with torch.no_grad():
            for iters, batch in enumerate(batch_loader):
                batch = self.batch_to_device(batch)
                if batch_loader.dataset.return_ids:
                    self.id_list.append(batch['ids'])
                self.eval_iter_step(iters, batch, test=test)

            self.probs_list = [
                prob for batch_prob in self.probs_list for prob in batch_prob
            ]
            self.preds_list = [
                pred for batch_pred in self.preds_list for pred in batch_pred
            ]
            self.labels_list = [
                label for batch_labels in self.labels_list
                for label in batch_labels
            ]
            self.id_list = [
                data_id for batch_id in self.id_list for data_id in batch_id
            ]

            val_loss = sum(self.loss_list) / len(self.loss_list)
            eval_metrics = standard_metrics(torch.tensor(self.probs_list),
                                            torch.tensor(self.labels_list),
                                            add_optimal_acc=True)
            # if test:
            # 	print(classification_report(np.array(self.labels_list), np.array(self.preds_list)))
        return eval_metrics, val_loss

    @torch.no_grad()
    def export_test_predictions(self, test_idx=0, threshold=0.5):
        self.model.eval()

        # Run the model on the test set (no loss is computed here).
        # Ensure that ids are actually returned.
        assert self.config['test_loader'][test_idx].dataset.return_ids, \
            "Can only export test results if the IDs are returned in the test dataset."
        test_name = self.config['test_loader'][test_idx].dataset.name

        prob_list = []
        id_list = []
        for iters, batch in enumerate(self.config['test_loader'][test_idx]):
            batch = self.batch_to_device(batch)
            id_list.append(batch['ids'].cpu())
            probs = self.test_iter_step(batch)
            if self.config['loss_func'] == 'bce_logits':
                probs = torch.sigmoid(probs)
            prob_list.append(probs.detach().cpu())

        probs = torch.cat(prob_list, dim=0)
        ids = torch.cat(id_list, dim=0)
        preds = (probs > threshold).long()

        # Export predictions
        self._export_preds(ids,
                           probs,
                           preds,
                           file_postfix="_%s_preds.csv" % test_name)

        LOGGER.info("Finished export of test predictions")

    @torch.no_grad()
    def export_val_predictions(self, test=False, test_idx=0, threshold=0.5):
        if test:
            batch_loader = self.config['test_loader'][test_idx]
        else:
            batch_loader = self.config['val_loader']
        test_name = batch_loader.dataset.name
        LOGGER.info("Exporting %s predictions..." % (test_name))
        self.model.eval()

        # Run evaluation to populate the probability, label, and id lists.
        _, _ = self.eval_model(test=test, test_idx=test_idx)
        val_probs = torch.tensor(self.probs_list)
        val_labels = torch.tensor(self.labels_list)
        if len(self.id_list) != 0:
            val_ids = torch.tensor(self.id_list)
        else:
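            # No ids were collected; use -1 as a placeholder id for every row.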
            val_ids = torch.zeros_like(val_labels) - 1
        val_preds = (val_probs > threshold).long()

        self._export_preds(val_ids,
                           val_probs,
                           val_preds,
                           labels=val_labels,
                           file_postfix="_%s_preds.csv" % test_name)

        LOGGER.info("Finished export of %s predictions" % test_name)

    def _export_preds(self,
                      ids,
                      probs,
                      preds,
                      labels=None,
                      file_postfix="_preds.csv"):
        file_string = "id,proba,label%s\n" % (",gt"
                                              if labels is not None else "")
        for i in range(ids.shape[0]):
            file_string += "%i,%f,%i" % (ids[i].item(), probs[i].item(),
                                         preds[i].item())
            if labels is not None:
                file_string += ",%i" % labels[i].item()
            file_string += "\n"
        filepath = os.path.join(
            self.config['model_path'],
            self.config['model_save_name'].rsplit(".", 1)[0] + file_postfix)
        with open(filepath, "w") as f:
            f.write(file_string)
        wandb.save(filepath)  # Upload the file to wandb

    def check_early_stopping(self):
        optimize_for = self.config['optimize_for']
        if optimize_for == 'loss':
            self.this_metric = self.val_loss
            self.current_best = self.best_val_loss
            new_best = self.this_metric < self.current_best
        else:
            self.this_metric = self.val_metrics[optimize_for]
            self.current_best = self.best_val_metrics[optimize_for]
            new_best = self.this_metric > self.current_best
        if new_best:
            LOGGER.info("New High Score! Saving model...")
            self.best_val_metrics = self.val_metrics
            self.best_val_loss = self.val_loss
            wandb.log({
                'Best val metrics': self.best_val_metrics,
                'Best val loss': self.best_val_loss
            })

            if not self.config["no_model_checkpoints"]:
                self.model_saver.save(self.model)

        ### Stopping Criteria based on patience and change-in-metric-threshold ###
        if optimize_for == 'loss':
            diff = self.current_best - self.this_metric
        else:
            diff = self.this_metric - self.current_best
        if diff < self.config['early_stop_thresh']:
            self.not_improved += 1
            if self.not_improved >= self.config['patience']:
                self.terminate_training = True
        else:
            self.not_improved = 0
        LOGGER.info("current patience: {}".format(self.not_improved))

    def train_epoch_step(self):
        self.model.train()
        lr = self.scheduler.get_last_lr()
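        # self.iters holds the index of the epoch's last batch, so this adds
        # the number of batches processed during the epoch.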
        self.total_iters += self.iters + 1
        self.probs_list = [
            pred for batch_pred in self.probs_list for pred in batch_pred
        ]
        self.labels_list = [
            label for batch_labels in self.labels_list
            for label in batch_labels
        ]

        # Evaluate on train set
        self.train_metrics = standard_metrics(torch.tensor(self.probs_list),
                                              torch.tensor(self.labels_list),
                                              add_optimal_acc=True)
        log_tensorboard(self.config,
                        self.config['writer'],
                        self.model,
                        self.epoch,
                        self.iters,
                        self.total_iters,
                        self.loss_list,
                        self.train_metrics,
                        lr[0],
                        loss_only=False,
                        val=False)
        self.train_loss = self.loss_list[:]

        # Evaluate on dev set
        val_time = time.time()
        self.val_metrics, self.val_loss = self.eval_model()
        self.config['writer'].add_scalar("Stats/time_validation",
                                         time.time() - val_time,
                                         self.total_iters)

        # print stats
        print_stats(self.config, self.epoch, self.train_metrics,
                    self.train_loss, self.val_metrics, self.val_loss,
                    self.start, lr[0])

        # log validation stats in tensorboard
        log_tensorboard(self.config,
                        self.config['writer'],
                        self.model,
                        self.epoch,
                        self.iters,
                        self.total_iters,
                        self.val_loss,
                        self.val_metrics,
                        lr[0],
                        loss_only=False,
                        val=True)

        # Check for early stopping criteria
        self.check_early_stopping()
        self.probs_list = []
        self.preds_list = []
        self.labels_list = []
        self.loss_list = []
        self.id_list = []

        self.train_loss = sum(self.train_loss) / len(self.train_loss)
        del self.val_metrics
        del self.val_loss

    def end_training(self):
        # Termination message
        print("\n" + "-" * 100)
        if self.terminate_training:
            LOGGER.info(
                "Training terminated early because the validation {} did not improve for {} epochs"
                .format(self.config['optimize_for'], self.config['patience']))
        else:
            LOGGER.info(
                "Maximum of {} epochs reached. Finished training!".format(
                    self.config['max_epoch']))

        print_test_stats(self.best_val_metrics, test=False)

        print("-" * 50 + "\n\t\tEvaluating on test set\n" + "-" * 50)
        if not self.config["no_model_checkpoints"]:
            if os.path.isfile(self.model_file):
                self.load_model()
                self.model.to(self.device)
            else:
                raise ValueError(
                    "No saved model state_dict found for model {}. Aborting evaluation on the test set."
                    .format(self.config['model_name']))

            # export_val_predictions() runs the evaluation itself, so the
            # probability and label lists are already populated afterwards.
            self.export_val_predictions()
            val_probs = torch.tensor(self.probs_list)
            val_labels = torch.tensor(self.labels_list)
            threshold = 0.5  # the default threshold for binary classification
            # Uncomment below line if you have implemented this optional feature
            # threshold = find_optimal_threshold(val_probs, val_labels, metric="accuracy")
            best_val_metrics = standard_metrics(val_probs,
                                                val_labels,
                                                threshold=threshold,
                                                add_aucroc=False)
            LOGGER.info(
                "Classification threshold on validation dataset: %.4f (accuracy=%4.2f%%)"
                % (threshold, 100.0 * best_val_metrics["accuracy"]))

            # Testing in the standard form is not possible when we do not have
            # any labels (standard_metrics raises an error); instead, we write
            # out the predictions in the leaderboard format.
            self.test_metrics = dict()
            for test_idx in range(len(self.config['test_loader'])):
                test_name = self.config['test_loader'][test_idx].dataset.name
                LOGGER.info("Export and testing on %s..." % test_name)
                if hasattr(self.config['test_loader'][test_idx].dataset, "data") and \
                   hasattr(self.config['test_loader'][test_idx].dataset.data, "labels") and \
                   self.config['test_loader'][test_idx].dataset.data.labels[0] == -1:
                    # Unlabeled test set (labels are -1): export predictions only
                    self.export_test_predictions(test_idx=test_idx,
                                                 threshold=threshold)
                    self.test_metrics[test_name] = dict()
                else:
                    test_idx_metrics, _ = self.eval_model(test=True,
                                                          test_idx=test_idx)
                    self.test_metrics[test_name] = test_idx_metrics
                    print_test_stats(test_idx_metrics, test=True)
                    self.export_val_predictions(test=True,
                                                test_idx=test_idx,
                                                threshold=threshold)
        else:
            LOGGER.info(
                "No model checkpoints were saved. Hence, testing will be skipped."
            )
            self.test_metrics = dict()

        self.export_metrics()

        self.config['writer'].close()

        if self.config['remove_checkpoints']:
            LOGGER.info("Removing checkpoint %s..." % self.model_file)
            os.remove(self.model_file)

    def export_metrics(self):
        metric_export_file = os.path.join(
            self.config['model_path'],
            self.config['model_save_name'].rsplit(".", 1)[0] + "_metrics.json")
        metric_dict = {}
        metric_dict["dev"] = self.best_val_metrics
        metric_dict["dev"]["loss"] = self.best_val_loss
        metric_dict["train"] = self.train_metrics
        metric_dict["train"]["loss"] = sum(
            self.train_loss) / len(self.train_loss) if isinstance(
                self.train_loss, list) else self.train_loss
        if hasattr(self, "test_metrics") and len(self.test_metrics) > 0:
            metric_dict["test"] = self.test_metrics

        with open(metric_export_file, "w") as f:
            json.dump(metric_dict, f, indent=4)

    def train_main(self, cache=False):
        print("\n\n" + "=" * 100 + "\n\t\t\t\t\t Training Network\n" +
              "=" * 100)

        self.start = time.time()
        print("\nBeginning training at:  {} \n".format(
            datetime.datetime.now()))

        self.model.to(self.device)

        for self.epoch in range(self.start_epoch,
                                self.config['max_epoch'] + 1):
            train_times = []
            for self.iters, self.batch in enumerate(
                    self.config['train_loader']):
                self.model.train()

                iter_time = time.time()
                self.batch = self.batch_to_device(self.batch)
                self.train_iter_step()
                train_times.append(time.time() - iter_time)

                # Loss only logging
                if (self.total_iters + self.iters +
                        1) % self.config['log_every'] == 0:
                    log_tensorboard(self.config,
                                    self.config['writer'],
                                    self.model,
                                    self.epoch,
                                    self.iters,
                                    self.total_iters,
                                    self.short_loss_list,
                                    loss_only=True,
                                    val=False)
                    self.config['writer'].add_scalar(
                        'Stats/time_per_train_iter', mean(train_times),
                        (self.iters + self.total_iters + 1))
                    self.config['writer'].add_scalar(
                        'Stats/learning_rate',
                        self.scheduler.get_last_lr()[0],
                        (self.iters + self.total_iters + 1))
                    train_times = []
                    self.short_loss_list = []
            self.train_epoch_step()

            if self.terminate_training:
                break

        self.end_training()
        return self.best_val_metrics, self.test_metrics

    def batch_to_device(self, batch):
        batch = {
            k: (v.to(self.device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()
        }
        return batch

    def eval_iter_step(self, iters, batch, test):
        # Forward pass
        preds = self.model(img_feat=batch['img_feat'],
                           img_pos_feat=batch['img_pos_feat'],
                           input_ids=batch['input_ids'],
                           position_ids=batch['position_ids'],
                           attention_mask=batch['attn_mask'],
                           gather_index=batch['gather_index'],
                           output_all_encoded_layers=False,
                           gender_race_probs=batch['gender_race_probs'])
        self.calculate_loss(preds, batch['labels'], grad_step=False)

    def train_iter_step(self):
        # Forward pass
        self.preds = self.model(
            img_feat=self.batch['img_feat'],
            img_pos_feat=self.batch['img_pos_feat'],
            input_ids=self.batch['input_ids'],
            position_ids=self.batch['position_ids'],
            attention_mask=self.batch['attn_mask'],
            gather_index=self.batch['gather_index'],
            output_all_encoded_layers=False,
            gender_race_probs=self.batch['gender_race_probs'])
        self.calculate_loss(self.preds, self.batch['labels'], grad_step=True)

    def test_iter_step(self, batch):
        # Forward pass
        preds = self.model(img_feat=batch['img_feat'],
                           img_pos_feat=batch['img_pos_feat'],
                           input_ids=batch['input_ids'],
                           position_ids=batch['position_ids'],
                           attention_mask=batch['attn_mask'],
                           gather_index=batch['gather_index'],
                           output_all_encoded_layers=False,
                           gender_race_probs=batch['gender_race_probs'])
        return preds.squeeze()
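
The trainer above steps the optimizer only when the iteration counter is a multiple of gradient_accumulation, rescaling the summed gradients via average_gradients before clipping. Below is a minimal, self-contained sketch of that pattern on a toy model; every name in it is illustrative and not part of the trainer, and it uses the common (step + 1) % accum_steps variant of the step condition.

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.BCEWithLogitsLoss()
accum_steps = 4      # plays the role of config['gradient_accumulation']
max_grad_norm = 1.0  # plays the role of config['max_grad_norm']

for step in range(16):
    x = torch.randn(8, 4)
    y = torch.randint(0, 2, (8, 1)).float()
    loss = criterion(model(x), y)
    loss.backward()  # gradients accumulate in param.grad across iterations
    if (step + 1) % accum_steps == 0:
        # Rescale the summed gradients to a per-batch average, as
        # average_gradients(steps=accum_steps) does above.
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad /= accum_steps
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()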
Example #2
class TrainerUniter(TrainerTemplate):
    def init_model(self):
        if self.pretrained_model_file:
            checkpoint = torch.load(self.pretrained_model_file)
            LOGGER.info('Using pretrained UNITER base model {}'.format(
                self.pretrained_model_file))
            base_model = UniterForPretraining.from_pretrained(
                self.config['config'],
                state_dict=checkpoint['model_state_dict'],
                img_dim=IMG_DIM,
                img_label_dim=IMG_LABEL_DIM)
            self.model = MemeUniter(
                uniter_model=base_model.uniter,
                hidden_size=base_model.uniter.config.hidden_size,
                n_classes=self.config['n_classes'])
        else:
            self.load_model()

    def load_model(self):
        # Load a fine-tuned checkpoint saved by this trainer
        if self.model_file:
            checkpoint = torch.load(self.model_file)
            LOGGER.info('Using UNITER model {}'.format(self.model_file))
        else:
            checkpoint = {}

        uniter_config = UniterConfig.from_json_file(self.config['config'])
        uniter_model = UniterModel(uniter_config, img_dim=IMG_DIM)

        self.model = MemeUniter(uniter_model=uniter_model,
                                hidden_size=uniter_model.config.hidden_size,
                                n_classes=self.config['n_classes'])
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])

    def eval_iter_step(self, iters, batch, test):
        # Forward pass
        preds = self.model(img_feat=batch['img_feat'],
                           img_pos_feat=batch['img_pos_feat'],
                           input_ids=batch['input_ids'],
                           position_ids=batch['position_ids'],
                           attention_mask=batch['attn_mask'],
                           gather_index=batch['gather_index'],
                           output_all_encoded_layers=False)
        self.calculate_loss(preds, batch['labels'], grad_step=False)

    def train_iter_step(self):
        # Forward pass
        self.preds = self.model(img_feat=self.batch['img_feat'],
                                img_pos_feat=self.batch['img_pos_feat'],
                                input_ids=self.batch['input_ids'],
                                position_ids=self.batch['position_ids'],
                                attention_mask=self.batch['attn_mask'],
                                gather_index=self.batch['gather_index'],
                                output_all_encoded_layers=False)
        self.calculate_loss(self.preds, self.batch['labels'], grad_step=True)

    def test_iter_step(self, batch):
        # Forward pass
        preds = self.model(img_feat=batch['img_feat'],
                           img_pos_feat=batch['img_pos_feat'],
                           input_ids=batch['input_ids'],
                           position_ids=batch['position_ids'],
                           attention_mask=batch['attn_mask'],
                           gather_index=batch['gather_index'],
                           output_all_encoded_layers=False)
        return preds.squeeze()
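
Both examples rely on the early-stopping logic shown in check_early_stopping in the first example (Example #2 presumably inherits it via TrainerTemplate): an epoch counts as "not improved" when the monitored metric beats the current best by less than early_stop_thresh, and training terminates after patience such epochs. The following standalone sketch reproduces that criterion; should_stop, history, and the sample values are illustrative names, not part of the trainer.

def should_stop(history, thresh=1e-4, patience=3, minimize=True):
    """Return True once `patience` consecutive epochs improve on the
    best value so far by less than `thresh`."""
    best = float('inf') if minimize else float('-inf')
    not_improved = 0
    for value in history:
        diff = (best - value) if minimize else (value - best)
        if diff < thresh:
            not_improved += 1
            if not_improved >= patience:
                return True
        else:
            not_improved = 0
        if (minimize and value < best) or (not minimize and value > best):
            best = value
    return False

# A validation loss that plateaus after the third epoch triggers the stop:
print(should_stop([0.9, 0.7, 0.6, 0.6, 0.6, 0.6]))  # True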