import os
import shutil
import tempfile

from tensorflow.keras import backend as K

# save_model_weights, load_model_weights, set_random_seed, WandBValidationLogger
# and PairwiseTraining are provided elsewhere in this package.


def train(self, epoch, draw_graph=False):
    print("Start the cross-validation for", self.train_collection.get_n_folds(), "folds")

    temp_dir = tempfile.mkdtemp()
    try:
        # save the model's initial state so every fold starts from the same weights
        save_model_weights(os.path.join(temp_dir, "temp_weights.h5"), self.model)

        # keep the base run name so each fold gets a fresh "Fold_XX_" prefix
        # instead of accumulating prefixes across folds
        base_run_name = self.wandb_config["name"] if self.wandb_config is not None else None

        best_test_scores = []

        for i, collections in enumerate(self.train_collection.generator()):
            print("Prepare FOLD", i)
            train_collection, test_collection = collections

            # show baseline metrics over the previous ranking order
            pre_metrics = test_collection.evaluate_pre_rerank()
            print("Evaluation of the original ranking order")
            for name, metric in pre_metrics:
                print(name, metric)

            # reset all the states
            set_random_seed()
            K.clear_session()

            # load the model's initial state
            load_model_weights(os.path.join(temp_dir, "temp_weights.h5"), self.model)

            # create the evaluation callback (check the config before indexing it)
            if self.wandb_config is not None:
                self.wandb_config["name"] = "Fold_{:02d}_{}".format(i, base_run_name)
                wandb_val_logger = WandBValidationLogger(
                    wandb_args=self.wandb_config,
                    steps_per_epoch=train_collection.get_steps(),
                    validation_collection=test_collection)
            else:
                raise KeyError("Please use wandb for now!!!")

            callbacks = [wandb_val_logger] + self.callbacks

            print("Train and test FOLD", i)
            pairwise_training = PairwiseTraining(model=self.model,
                                                 train_collection=train_collection,
                                                 loss=self.loss,
                                                 optimizer=self.optimizer,
                                                 callbacks=callbacks)
            pairwise_training.train(epoch, draw_graph=draw_graph)

            # record this fold's best validation score after training has run
            best_test_scores.append(wandb_val_logger.current_best)

        x_score = sum(best_test_scores) / len(best_test_scores)
        print("Cross-validation mean best score:", x_score)
        wandb_val_logger.wandb.run.summary["best_xval_" + wandb_val_logger.comparison_metric] = x_score
    except Exception:
        raise  # maybe handle the exception in the future
    finally:
        # always remove the temporary weights directory
        print("Remove {}".format(temp_dir))
        shutil.rmtree(temp_dir)
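# --- Hypothetical sketch (not part of the package) ---------------------------
# The weight-snapshot and seeding helpers called above are imported from
# elsewhere and not shown here. Assuming a Keras model on TensorFlow 2.x, they
# could be as simple as the following; the real implementations may differ.
import random

import numpy as np
import tensorflow as tf


def save_model_weights(path, model):
    # snapshot the current weights so every fold can restart from the same state
    model.save_weights(path)


def load_model_weights(path, model):
    # restore the snapshot taken by save_model_weights
    model.load_weights(path)


def set_random_seed(seed=42):
    # reseed every RNG that influences training so the folds are comparable
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
# -----------------------------------------------------------------------------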
def on_train_start(self, training_obj):
    # snapshot the model's initial weights so they can be restored later,
    # e.g. before each cross-validation fold begins training
    save_model_weights(os.path.join(self.temp_dir, "temp.h5"), training_obj.model)
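# --- Hypothetical sketch (not part of the package) ---------------------------
# on_train_start above only makes sense inside a callback protocol where the
# training loop passes itself to each callback before the first epoch. A
# minimal version of that dispatch is sketched below; the class and attribute
# names are assumptions, not the package's actual training loop.


class MinimalTrainingLoop:

    def __init__(self, model, callbacks=None):
        self.model = model
        self.callbacks = callbacks or []

    def train(self, epochs):
        # let each callback observe (or snapshot) the model before training
        for callback in self.callbacks:
            callback.on_train_start(self)
        for _ in range(epochs):
            pass  # one epoch of pairwise training would run here
# -----------------------------------------------------------------------------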