def _create_train_callbacks(self) -> List[Callback]:
        """Assemble the Keras callbacks handed to the training loop.

        The training-status tracker and the progress logger are always
        present. After them, in this order, come a checkpoint saver (when
        ``config.is_saving``), an evaluation callback (when
        ``config.is_testing``) and a TensorBoard writer (when
        ``config.USE_TENSORBOARD``).

        Returns:
            The list of callbacks, in the order they were appended.
        """
        # TODO: decide whether to use early stopping; if so, pick the right
        #       checkpoint manager and set the `monitor` quantity accordingly
        #       (example: monitor='val_acc', mode='max').
        cfg = self.config
        status = self.training_status

        callbacks = [ModelTrainingStatusTrackerCallback(status)]
        callbacks.append(ModelTrainingProgressLoggerCallback(cfg, status))

        if cfg.is_saving:
            checkpoint_saver = ModelCheckpointSaverCallback(
                self, cfg.SAVE_EVERY_EPOCHS, self.logger)
            callbacks.append(checkpoint_saver)

        if cfg.is_testing:
            callbacks.append(ModelEvaluationCallback(self))

        if cfg.USE_TENSORBOARD:
            tensorboard_dir = "logs/scalars/train_" + common.now_str()
            # NOTE(review): this frequency is batches * batch size, i.e. a
            # sample count. Some tf.keras versions read an integer
            # `update_freq` as a number of batches instead — confirm against
            # the installed TensorFlow version.
            log_every = cfg.NUM_BATCHES_TO_LOG_PROGRESS * cfg.TRAIN_BATCH_SIZE
            callbacks.append(keras.callbacks.TensorBoard(
                log_dir=tensorboard_dir,
                update_freq=log_every))

        return callbacks
Esempio n. 2
0
    def _create_train_callbacks(self) -> List[Callback]:
        """Assemble the Keras callbacks handed to the training loop.

        The training-status tracker and the progress logger are always
        present. After them, in this order, come:
          * ``EarlyStopping`` on the training loss (when
            ``config.EARLY_STOPPING``), with ``config.PATIENCE`` as patience;
          * a checkpoint saver (when ``config.is_saving``);
          * an evaluation callback (when ``config.is_testing``);
          * a TensorBoard writer (when ``config.USE_TENSORBOARD``).

        Returns:
            The list of callbacks, in the order they were appended.
        """
        cfg = self.config
        status = self.training_status

        callbacks = [ModelTrainingStatusTrackerCallback(status)]
        callbacks.append(ModelTrainingProgressLoggerCallback(cfg, status))

        if cfg.EARLY_STOPPING:
            # NOTE: this monitors the *training* loss. Monitoring a validation
            # metric instead (e.g. monitor='val_acc', mode='max') would also
            # require a matching checkpoint strategy.
            early_stopper = tf.keras.callbacks.EarlyStopping(
                monitor='loss', patience=cfg.PATIENCE)
            callbacks.append(early_stopper)

        if cfg.is_saving:
            checkpoint_saver = ModelCheckpointSaverCallback(
                self, cfg.SAVE_EVERY_EPOCHS, self.logger)
            callbacks.append(checkpoint_saver)

        if cfg.is_testing:
            callbacks.append(ModelEvaluationCallback(self))

        if cfg.USE_TENSORBOARD:
            tensorboard_dir = "logs/scalars/train_" + common.now_str()
            # NOTE(review): this frequency is batches * batch size, i.e. a
            # sample count. Some tf.keras versions read an integer
            # `update_freq` as a number of batches instead — confirm against
            # the installed TensorFlow version.
            log_every = cfg.NUM_BATCHES_TO_LOG_PROGRESS * cfg.TRAIN_BATCH_SIZE
            callbacks.append(keras.callbacks.TensorBoard(
                log_dir=tensorboard_dir,
                update_freq=log_every))

        return callbacks