Exemple #1
0
    def train(self, epoch, draw_graph=False):
        """Train and evaluate the model once per cross-validation fold.

        The initial model weights are saved to a temporary directory and
        reloaded before each fold, so every fold starts from the same state.
        After all folds are trained, the mean of the per-fold best validation
        scores is printed and written to the wandb run summary.

        Args:
            epoch: forwarded unchanged to ``PairwiseTraining.train``.
            draw_graph: forwarded unchanged to ``PairwiseTraining.train``.

        Raises:
            KeyError: if ``self.wandb_config`` is ``None`` (wandb logging is
                currently mandatory).
        """
        # Fail fast, before any temp files are created or the config is
        # subscripted; wandb is the only supported validation logger.
        if self.wandb_config is None:
            raise KeyError("Please use wandb for now!!!")

        print("Start the Cross validation for",
              self.train_collection.get_n_folds(), "folds")

        temp_dir = tempfile.mkdtemp()
        init_weights_path = os.path.join(temp_dir, "temp_weights.h5")

        # Remember the un-prefixed run name so each fold derives its name from
        # it instead of stacking prefixes ("Fold_01_Fold_00_...").
        base_run_name = self.wandb_config["name"]

        try:
            # save model init state
            save_model_weights(init_weights_path, self.model)

            best_test_scores = []
            wandb_val_logger = None
            for i, collections in enumerate(self.train_collection.generator()):
                print("Prepare FOLD", i)

                train_collection, test_collection = collections

                # show baseline metrics over the previous ranking order
                pre_metrics = test_collection.evaluate_pre_rerank()
                print("Evaluation of the original ranking order")
                for n, m in pre_metrics:
                    print(n, m)

                # reset all the states
                set_random_seed()
                K.clear_session()

                # load model init state
                load_model_weights(init_weights_path, self.model)

                self.wandb_config["name"] = (
                    "Fold_0" + str(i) + "_" + base_run_name)

                # create evaluation callback
                wandb_val_logger = WandBValidationLogger(
                    wandb_args=self.wandb_config,
                    steps_per_epoch=train_collection.get_steps(),
                    validation_collection=test_collection)

                callbacks = [wandb_val_logger] + self.callbacks

                print("Train and test FOLD", i)

                pairwise_training = PairwiseTraining(
                    model=self.model,
                    train_collection=train_collection,
                    loss=self.loss,
                    optimizer=self.optimizer,
                    callbacks=callbacks)

                pairwise_training.train(epoch, draw_graph=draw_graph)

                # Read the best score only after the fold has been trained;
                # reading it earlier would capture the logger's initial value.
                # NOTE(review): assumes current_best is a number updated by
                # the logger during training - confirm against the logger.
                best_test_scores.append(wandb_val_logger.current_best)

            if not best_test_scores:
                # No fold was produced: there is nothing to average or log.
                print("No folds were generated, skipping X validation score")
                return

            x_score = sum(best_test_scores) / len(best_test_scores)
            print("X validation best score:", x_score)
            # The summary is written through the logger of the last fold.
            wandb_val_logger.wandb.run.summary[
                "best_xval_" + wandb_val_logger.comparison_metric] = x_score

        finally:
            # Restore the caller's run name so repeated calls do not
            # accumulate fold prefixes.
            self.wandb_config["name"] = base_run_name
            # always remove the temp directory
            print("Remove {}".format(temp_dir))
            shutil.rmtree(temp_dir)
Exemple #2
0
 def on_train_start(self, training_obj):
     """Save the weights of ``training_obj.model`` to ``temp.h5`` in ``self.temp_dir``."""
     weights_path = os.path.join(self.temp_dir, "temp.h5")
     model = training_obj.model
     save_model_weights(weights_path, model)