Example #1
    def build_model(self, monitor=Monitor.FILTERED_MEAN_RANK):
        """function to build the model"""
        if self.config.load_from_data is not None:
            self.load_model(self.config.load_from_data)

        self.evaluator = Evaluator(self.model, self.config)

        self.model.to(self.config.device)

        if self.config.optimizer == "adam":
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.config.learning_rate,
            )
        elif self.config.optimizer == "sgd":
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.config.learning_rate,
            )
        elif self.config.optimizer == "adagrad":
            self.optimizer = optim.Adagrad(
                self.model.parameters(),
                lr=self.config.learning_rate,
            )
        elif self.config.optimizer == "rms":
            self.optimizer = optim.RMSprop(
                self.model.parameters(),
                lr=self.config.learning_rate,
            )
        else:
            raise NotImplementedError("No support for %s optimizer" % self.config.optimizer)

        self.config.summary()

        self.early_stopper = EarlyStopper(self.config.patience, monitor)
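
The optimizer block above is a straight if/elif dispatch on `self.config.optimizer`. As a design-choice note, the same selection can be written as a lookup table; the sketch below is illustrative only, reuses the field names from the snippet (`optimizer`, `learning_rate`), and `make_optimizer` is a hypothetical helper, not library API.

import torch.optim as optim

# Illustrative sketch only: the same optimizer dispatch as above, expressed as a
# lookup table. `config.optimizer` / `config.learning_rate` mirror the fields
# used in the snippet; `make_optimizer` is a hypothetical helper.
_OPTIMIZERS = {
    "adam": optim.Adam,
    "sgd": optim.SGD,
    "adagrad": optim.Adagrad,
    "rms": optim.RMSprop,
}

def make_optimizer(model, config):
    try:
        optimizer_cls = _OPTIMIZERS[config.optimizer]
    except KeyError:
        raise NotImplementedError("No support for %s optimizer" % config.optimizer)
    return optimizer_cls(model.parameters(), lr=config.learning_rate)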
Example #2
    def tune_model(self):
        """Function to tune the model."""
        current_loss = float("inf")

        self.generator = Generator(self.model, self.config)
        self.evaluator = Evaluator(self.model, self.config, tuning=True)

        for cur_epoch_idx in range(self.config.epochs):
            current_loss = self.train_model_epoch(cur_epoch_idx, tuning=True)

        # Evaluate on the full test set after the final tuning epoch.
        self.evaluator.full_test(cur_epoch_idx)

        self.generator.stop()

        return current_loss
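
Because tune_model returns the training loss of the last epoch, it can drive a simple hyperparameter sweep. The following is a rough sketch under the assumption that a Trainer-style object with a mutable config, as in these snippets, is available; none of the names are guaranteed library API.

# Rough sketch, assuming a Trainer-like class and a mutable config as in the
# snippets above (assumptions, not the library's documented API).
best_loss, best_lr = float("inf"), None
for lr in (0.1, 0.01, 0.001):
    config.learning_rate = lr              # hypothetical: vary one hyperparameter
    trainer = Trainer(model=model, config=config)
    trainer.build_model()                  # sets up optimizer / evaluator (Example #1)
    loss = trainer.tune_model()            # last epoch's training loss (Example #2)
    if loss < best_loss:
        best_loss, best_lr = loss, lr
print("best learning rate:", best_lr, "with loss", best_loss)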
Example #3
    def enter_interactive_mode(self):
        self.build_model()
        self.load_model()

        self.evaluator = Evaluator(self.model)
        self._logger.info("""The training/loading of the model has finished!
                                    Now enter interactive mode :)
                                    -----
                                    Example 1: trainer.infer_tails(1,10,topk=5)""")
        self.infer_tails(1,10,topk=5)

        self._logger.info("""-----
                                    Example 2: trainer.infer_heads(10,20,topk=5)""")
        self.infer_heads(10,20,topk=5)

        self._logger.info("""-----
                                    Example 3: trainer.infer_rels(1,20,topk=5)""")
        self.infer_rels(1,20,topk=5)
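
enter_interactive_mode builds and loads the model before handing control to the user, so it is typically the last call in a small driver script. A hedged sketch of such a driver follows; Trainer, model, and config are assumptions based on these snippets, not the library's documented API.

# Hedged driver sketch; Trainer, model, and config are placeholders.
trainer = Trainer(model=model, config=config)
trainer.enter_interactive_mode()      # builds + loads the model, then runs the demo queries
# After the demo, the same helpers remain available for ad-hoc queries:
trainer.infer_tails(1, 10, topk=5)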
Example #4
    def train_model(self, monitor=Monitor.FILTERED_MEAN_RANK):
        """Function to train the model."""
        # Set up the data generator and the evaluator for this training run.
        self.generator = Generator(self.model)
        self.evaluator = Evaluator(self.model)

        if self.config.loadFromData:
            self.load_model()
        
        for cur_epoch_idx in range(self.config.epochs):
            self._logger.info("Epoch[%d/%d]" % (cur_epoch_idx, self.config.epochs))
            
            self.train_model_epoch(cur_epoch_idx)

            if cur_epoch_idx % self.config.test_step == 0:
                metrics = self.evaluator.mini_test(cur_epoch_idx)
                              
                # Early stopping: after each mini-test, check whether the monitored
                # metric is still improving. For example, if test_step == 5, the
                # trainer checks the metrics every 5 epochs.
                if self.early_stopper.should_stop(metrics):
                    break

        self.evaluator.full_test(cur_epoch_idx)
        self.evaluator.metric_calculator.save_test_summary(self.model.model_name)

        self.generator.stop()
        self.save_training_result()

        if self.config.save_model:
            self.save_model()

        if self.config.disp_result:
            self.display()

        if self.config.disp_summary:
            self.config.summary()
            self.config.summary_hyperparameter(self.model.model_name)

        self.export_embeddings()

        return cur_epoch_idx  # the number of epochs actually run
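
Putting the pieces together: build_model (Example #1) creates the optimizer and the early stopper that train_model relies on, so it has to run first. Below is a minimal, hedged sketch of the full workflow, with Trainer, model, and config standing in for whatever the surrounding code actually provides.

# Minimal end-to-end sketch; Trainer, model, and config are placeholders
# assumed from the snippets, not a documented API.
trainer = Trainer(model=model, config=config)
trainer.build_model(monitor=Monitor.FILTERED_MEAN_RANK)  # optimizer + early stopper
last_epoch = trainer.train_model()                       # trains, mini-tests, may stop early
print("training stopped after epoch", last_epoch)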