Example #1
import os

import tensorflow as tf

# ModelFactory is assumed to be provided by the surrounding project's models module.


def load_and_get_model_for_inference(trained_model_arch,
                                     trained_checkpoint_dir, filetype,
                                     input_shape, num_classes):
    model_factory = ModelFactory()
    model = model_factory.get_model(
        trained_model_arch,
        input_shape,
        is_training=False,
        num_classes=num_classes,
        learning_rate=0.001)  # A dummy learning rate since it is test mode.
    # The ModelCheckpoint callback in the train pipeline saves the weights inside the checkpoint directory as follows.
    if filetype == '.h5':
        weights_path = os.path.join(trained_checkpoint_dir,
                                    "best_model_dir-auc.h5")
        # The HDF5 file contains the full model, so it replaces the freshly built architecture.
        model = tf.keras.models.load_model(weights_path)
    elif filetype == 'tf':
        weights_path = os.path.join(trained_checkpoint_dir, "variables",
                                    "variables")
        model.load_weights(weights_path)
    else:
        raise ValueError(
            "The provided saved model filetype is not recognized: %s" % filetype)

    print(
        "The model has been created and the weights have been loaded from: %s"
        % weights_path)
    model.summary()
    return model
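
A minimal usage sketch for this loader; the architecture name, checkpoint path, and shapes are assumptions for illustration, not values from the original project:

# Hypothetical call, mirroring the SavedModel checkpoint layout produced by
# the train pipeline in Example #3 below.
model = load_and_get_model_for_inference(
    trained_model_arch="vgg16",
    trained_checkpoint_dir="runs/experiment_1/best_model_dir-auc.ckpt",
    filetype="tf",
    input_shape=(224, 224, 3),
    num_classes=1)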
Example #2
def load_model():
    # NOTE: class_names, model_type, img_height, and img_width are module-level globals in the original project.
    model_file_path = 'src/best_weights_1555982768.7076797.h5'
    #model_file_path = 'best_weights_1555982768.7076797.h5'
    model_factory = ModelFactory()
    model = model_factory.get_model(class_names,
                                    model_name=model_type,
                                    use_base_weights=False,
                                    weights_path=model_file_path,
                                    input_shape=(img_height, img_width, 3))
    # NOTE: `lr` is the legacy Keras argument name; newer releases use `learning_rate`.
    optimizer = keras.optimizers.Adam(lr=1e-3, beta_1=0.9, beta_2=0.999)
    model.compile(optimizer=optimizer,
                  loss="binary_crossentropy",
                  metrics=["accuracy", "binary_accuracy"])
    model.load_weights(model_file_path)

    return model
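
Note that the weights are effectively loaded twice above: once through the factory's weights_path argument and again via load_weights on the same file, which re-applies identical weights. A minimal inference sketch with the returned model; the zero-filled batch is a placeholder:

import numpy as np

model = load_model()
# Placeholder input matching the model's expected shape; img_height and
# img_width are the same module-level globals used inside load_model().
batch = np.zeros((1, img_height, img_width, 3), dtype="float32")
probabilities = model.predict(batch)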
Example #3
def train(train_metadata_file_path,
          val_metadata_file_path,
          images_dir_path,
          out_dir,
          model_arch,
          num_classes,
          label_name=None,
          sequence_image_count=1,
          data_pipeline_mode="mode_flat_all",
          class_weight=None,
          whole_epochs=100,
          batch_size=32,
          learning_rate=0.001,
          patience=2,
          min_delta_auc=0.01,
          input_size=(224, 224, 3)):
    """
    Train a model of the given architecture on single images.

    :param train_metadata_file_path: The path to the metadata '.csv' file containing training image names.
    :param val_metadata_file_path: The path to the metadata '.csv' file containing validation image names.
    :param images_dir_path: The path containing the images.
    :param out_dir: The path to which the saved models need to be written.
    :param model_arch: The model architecture, provided as a string naming one of the architectures in the 'models' module.
    :param num_classes: The number of classes present in the data. If num_classes=1, it requires the 'label_name'.
    :param label_name: Required if num_classes=1. The name of the label to pick from the data.
    :param sequence_image_count: The number of images in the sequence dataset. Default: 1.
    :param data_pipeline_mode: The mode of the data pipeline. Default: "mode_flat_all".
    :param class_weight: The class weights for imbalanced data. Example: {0: 1.0, 1: 0.5}, if class "0" has half as
        many samples as class "1" in your data. Default: None.
    :param whole_epochs: The maximum number of epochs to be trained. Note that the model may be early-stopped. Default: 100.
    :param batch_size: The batch size used for the data. Ensure that it fits within the GPU memory. Default: 32.
    :param learning_rate: The constant learning rate to be used for the Adam optimizer. Default: 0.001.
    :param patience: The number of full passes over the train dataset to wait before early stopping. Default: 2.
    :param min_delta_auc: The minimum improvement in validation AUC required to reset the early-stopping patience. Default: 0.01.
    :param input_size: The shape of the tensors returned by the data pipeline mode. Default: (224, 224, 3).

    """
    if num_classes == 1 and label_name is None:
        raise ValueError(
            "Since num_classes equals 1, the label_name must be provided.")

    train_data_epoch_subdivisions = 4
    early_stop_monitor = "val_auc"
    early_stop_min_delta = min_delta_auc
    early_stop_patience = patience * train_data_epoch_subdivisions  # Patience is given in full passes over the train dataset.
    prefetch_buffer_size = 3  # Can also be set to tf.data.experimental.AUTOTUNE.

    os.makedirs(out_dir)

    # Build model architecture.
    model_factory = ModelFactory()
    model = model_factory.get_model(model_arch,
                                    input_size,
                                    is_training=True,
                                    num_classes=num_classes,
                                    learning_rate=learning_rate)
    print("Created the model architecture: %s" % model.name)
    model.summary()

    # Prepare the training dataset.
    print("Preparing training and validation datasets.")
    train_data_pipeline = PipelineGenerator(
        train_metadata_file_path,
        images_dir_path,  # XXX: This call requires the path to end with a slash.
        # This should be handled inside the PipelineGenerator.
        is_training=True,
        sequence_image_count=sequence_image_count,
        label_name=label_name,
        mode=data_pipeline_mode)
    train_dataset = train_data_pipeline.get_pipeline()
    train_dataset = train_dataset.batch(batch_size).prefetch(
        prefetch_buffer_size)

    # Prepare the validation dataset
    val_data_pipeline = PipelineGenerator(
        val_metadata_file_path,
        images_dir_path,
        is_training=False,
        sequence_image_count=sequence_image_count,
        label_name=label_name,
        mode=data_pipeline_mode)
    val_dataset = val_data_pipeline.get_pipeline()
    val_dataset = val_dataset.batch(batch_size).prefetch(prefetch_buffer_size)

    # TODO: Find a way to log the activation maps, either during training, or after the training has completed.

    # Prepare the callbacks.
    print("Preparing Tensorflow Keras Callbacks.")
    earlystop_callback = keras.callbacks.EarlyStopping(
        monitor=early_stop_monitor,
        min_delta=early_stop_min_delta,
        patience=early_stop_patience)

    # XXX: We use the HDF5 format to store the sequence models due to a bug in the tensorflow TimeDistributed wrapper.
    if data_pipeline_mode in PipelineGenerator.TIMESTEP_MODES:
        model_extension = ".h5"
    else:
        model_extension = ".ckpt"

    best_model_checkpoint_auc_callback = keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(out_dir, "best_model_dir-auc" + model_extension),
        mode='max',
        monitor='val_auc',
        save_best_only=True,
        save_weights_only=False,
        verbose=1)
    best_model_checkpoint_loss_callback = keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(out_dir,
                              "best_model_dir-loss" + model_extension),
        mode='min',
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=False,
        verbose=1)

    tensorboard_callback = keras.callbacks.TensorBoard(
        log_dir=os.path.join(out_dir, "TBGraph"),
        write_graph=True,
        write_images=True)

    callbacks = [
        earlystop_callback, best_model_checkpoint_auc_callback,
        best_model_checkpoint_loss_callback, tensorboard_callback
    ]

    # Start model training.
    # Defining an 'epoch' to be a quarter of the train dataset.
    num_train_samples = train_data_pipeline.get_size()
    num_val_samples = val_data_pipeline.get_size()
    # Number of batches per one run through the train dataset.
    num_training_steps_per_whole_dataset = int(num_train_samples / batch_size)
    num_val_steps_per_whole_dataset = int(num_val_samples / batch_size)
    steps_per_epoch = int(num_training_steps_per_whole_dataset /
                          train_data_epoch_subdivisions)
    max_num_epochs = int(whole_epochs * train_data_epoch_subdivisions)
    max_train_steps = int(max_num_epochs * steps_per_epoch)
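    # Worked example (numbers assumed for illustration): with 10,000 train
    # samples and batch_size=32, one pass over the data is ~312 batches;
    # with 4 subdivisions, steps_per_epoch is ~78, so whole_epochs=100
    # yields max_num_epochs=400 and max_train_steps of ~31,200.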

    print(
        "Number of train samples: %s, which correspond to  ~%s batches for one complete run through the "
        "train dataset. Number of validation samples: %s, which correspond to ~%s batches for complete iteration. "
        "Considering a 1/%s fraction of the train dataset as an epoch (steps_per_epoch: %s) "
        "after which validation and model checkpoints are saved. Running training for a maximum of %s steps, "
        "which correspond to max_num_epochs: %s (whole_epochs: %s). "
        "Early stopping has been set based on '%s' of min_delta of %s with a patience of %s."
        % (num_train_samples, num_training_steps_per_whole_dataset,
           num_val_samples, num_val_steps_per_whole_dataset,
           train_data_epoch_subdivisions, steps_per_epoch, max_train_steps,
           max_num_epochs, whole_epochs, early_stop_monitor,
           early_stop_min_delta, early_stop_patience))

    print("\nStarting the model training.")
    start_time = time.time()

    model.fit(train_dataset,
              epochs=max_num_epochs,
              steps_per_epoch=steps_per_epoch,
              validation_data=val_dataset,
              validation_steps=num_val_steps_per_whole_dataset,
              callbacks=callbacks,
              class_weight=class_weight)

    time_taken = time.time() - start_time
    print(
        "Training completed and the output has been saved in %s. Time taken: %s seconds."
        % (out_dir, time_taken))
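
A minimal sketch of invoking this training entry point; every path and name below is an assumption for illustration:

train(train_metadata_file_path="data/train_metadata.csv",
      val_metadata_file_path="data/val_metadata.csv",
      images_dir_path="data/images/",  # Trailing slash required; see the XXX note above.
      out_dir="runs/experiment_1",
      model_arch="vgg16",
      num_classes=1,
      label_name="label",
      class_weight={0: 1.0, 1: 0.5},
      batch_size=32)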
Example #4
class MaeBird(QMainWindow, Ui_MainWindow):
    def __init__(self, parent=None):
        super(MaeBird, self).__init__(parent)
        
        self.models = ModelFactory()

        self.table = None
        self.languages = {'primary': 'Fin', 'secondary': 'Eng',
                          'tertiary': 'Swe'}
        
        self.dbtype = __DB__
        self.dbfile = None
        self.db = None
        
        self.matches = []
        self.currentsearchitem = 0
        
        self.fullscreen = False
        self.setupUi(self)
        
        self.setWindowTitle(__APPNAME__ + ' ' + __VERSION__)
        
        # TODO: loading settings should be moved to a separate method
        settings = QSettings()
        
        # Set up logging
        loggingdir = settings.value("Logging/loggingDir")
        if loggingdir is None:
            loggingdir = __USER_DATA_DIR__
        self.logger = Logger('root', loggingdir=loggingdir)
        if settings.value("Settings/debugging"):
            self.logger.debugging = int(settings.value("Settings/debugging"))
            self.logger.debug('Logging initialized')
        
        # Try to load previous session
        if settings.value("Settings/saveSettings"):
            self.saveSettings = int(settings.value("Settings/saveSettings"))
        else:
            self.saveSettings = 1
                      
        if self.saveSettings:
            QTimer.singleShot(0, self.load_initial_data)
            #QTimer.singleShot(0, self.load_initial_model)
        
        self.header = self.tableView.horizontalHeader()
        self.header.sectionDoubleClicked.connect(self.sort_table)
        
        self.search.textEdited.connect(self.update_ui)
        self.search.setFocus()
        self.searchNextButton.clicked.connect(self.update_ui)
        self.searchPrevButton.clicked.connect(self.update_ui)
        
        self.tableView.pressed.connect(self.update_ui)
        
        self.tableView.doubleClicked.connect(
                    lambda: self.handle_observation(ObservationDialog.SHOW))
        self.addButton.clicked.connect(
                    lambda: self.handle_observation(ObservationDialog.ADD))
        self.deleteButton.clicked.connect(
                    lambda: self.handle_observation(ObservationDialog.DELETE))
        
    def closeEvent(self, event):
        settings = QSettings()
        if self.saveSettings:
            db = self.dbfile if self.db is not None else ''
            settings.setValue("Database/LastDb", db)
            
            if self.tableView.model() is not None:
                settings.setValue("Database/DefaultModel", self.tableView.model().name)
            
                visible_fields = [not self.tableView.isColumnHidden(i)
                                  for i in range(self.tableView.model().columnCount())]
                settings.setValue("Database/visibleFields", visible_fields)
            settings.setValue("Settings/debugging", int(self.logger.debugging))
        
        settings.setValue("Settings/saveSettings", int(self.saveSettings))
    
    def load_initial_data(self):
        settings = QSettings()
        dbfile = unicode(settings.value("Database/LastDb"))
        modelname = unicode(settings.value("Database/DefaultModel"))
        if dbfile and QFile.exists(dbfile):
            self.load_db(dbfile, modelname=modelname)
            self.logger.debug("Loaded database %s with model %s" % (dbfile,
                                                                     modelname))
        
        if settings.value("Database/visibleFields"):
            visible_fields = [item for item in settings.value("Database/visibleFields")]
        
            # FIXME: in absence of QVariant, deal with values
            visible_fields = [False if item == 'false' else True for item in visible_fields]
            if not all(visible_fields):
                self.logger.debug("Hiding fields %s" % visible_fields)
            self.show_fields(visible_fields)

    def load_db(self, dbname, modelname=None):
        self.db = QSqlDatabase.addDatabase(self.dbtype)
        self.db.setDatabaseName(dbname)
        if not self.db.open():
            QMessageBox.warning(self, "Batabase connection",
                "Database Error: %s" % (self.db.lastError().text()))
            return
        self.dbfile = dbname
        
        if modelname not in self.models.model_names:
            modeldlg = ModelDialog(self.models.model_names)
            if modeldlg.exec_():
                modelname = modeldlg.selected_model()
        
        if modelname:
            self.load_model(modelname)
    
    def load_model(self, modelname):
        ''' Load a specific database model and set it on the view.
        '''
        try:
            model = self.models.get_model(modelname)
        except NotImplementedError as e:
            QMessageBox.warning(self, "Database model",
                "Database Model Error: %s" % str(e))
            return
        self.tableView.setModel(model(self))
        self.tableView.setItemDelegate(QSqlRelationalDelegate(self))
        self.tableView.setSelectionMode(QTableView.SingleSelection)
        self.tableView.setSelectionBehavior(QTableView.SelectRows)
        self.tableView.setColumnHidden(0, True)
        self.tableView.resizeColumnsToContents()
        self.update_ui()
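
The FIXME in load_initial_data reflects that QSettings hands back booleans as the strings 'true'/'false' when values are stored without QVariant wrapping. A small helper sketching the conversion; the function name and usage are assumptions for illustration:

def settings_bool_list(settings, key):
    # QSettings may return a list of 'true'/'false' strings rather than
    # Python booleans; normalize both representations.
    values = settings.value(key) or []
    return [v if isinstance(v, bool) else v == 'true' for v in values]

# Inside load_initial_data this would replace the manual conversion:
# visible_fields = settings_bool_list(settings, "Database/visibleFields")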