Beispiel #1
0
def train_model(filename_train_validation_set,
                filename_labels_train_validation_set,
                filename_sample_weights,
                filter_density1,
                filter_density2,
                pool_n_row,
                pool_n_col,
                dropout,
                input_shape,
                file_path_model,
                filename_log):
    """
    Build a ``jordi_model`` network and pass it to ``model_train``.

    The label file and the sample-weight file are read through
    ``load_data_jingju``; all ten values it returns are forwarded to
    ``model_train`` together with the feature-set file name, the model
    output path and the log file name.  The parameter count of the new
    network is printed before ``model_train`` is called.

    The network is created with the two filter densities, the two pooling
    sizes, ``dropout``, ``input_shape`` and the fixed string ``'temporal'``.
    A batch size of 128 and a patience of 10 are always used.

    Nothing is returned.
    """
    # Fixed settings handed to model_train.
    n_batch = 128
    patience_value = 10

    loaded = load_data_jingju(filename_labels_train_validation_set,
                              filename_sample_weights)

    # load_data_jingju is expected to yield exactly ten items.
    (train_files, train_labels, train_weights,
     val_files, val_labels, val_weights,
     feature_files, all_labels, all_weights, weights_per_class) = loaded

    network = jordi_model(filter_density1,
                          filter_density2,
                          pool_n_row,
                          pool_n_col,
                          dropout,
                          input_shape,
                          'temporal')

    print(network.count_params())

    model_train(network,
                n_batch,
                patience_value,
                input_shape,
                filename_train_validation_set,
                train_files,
                train_labels,
                train_weights,
                val_files,
                val_labels,
                val_weights,
                feature_files,
                all_labels,
                all_weights,
                weights_per_class,
                file_path_model,
                filename_log)
def train_model(filename_train_validation_set,
                filename_labels_train_validation_set,
                filename_sample_weights,
                filter_density,
                dropout,
                input_shape,
                file_path_model,
                filename_log):
    """
    Build a ``jan`` network (with ``batchNorm=True``) and pass it to
    ``model_train``.

    The label file and the sample-weight file are read through
    ``load_data_jingju``; all ten values it returns are forwarded to
    ``model_train`` together with the feature-set file name, the model
    output path and the log file name.  The parameter count of the new
    network is printed before ``model_train`` is called.

    A batch size of 128 and a patience of 10 are always used.

    Nothing is returned.
    """
    # Fixed settings handed to model_train.
    n_batch = 128
    patience_value = 10

    loaded = load_data_jingju(filename_labels_train_validation_set,
                              filename_sample_weights)

    # load_data_jingju is expected to yield exactly ten items.
    (train_files, train_labels, train_weights,
     val_files, val_labels, val_weights,
     feature_files, all_labels, all_weights, weights_per_class) = loaded

    network = jan(filter_density=filter_density, dropout=dropout,
                  input_shape=input_shape, batchNorm=True)

    print(network.count_params())

    model_train(network,
                n_batch,
                patience_value,
                input_shape,
                filename_train_validation_set,
                train_files,
                train_labels,
                train_weights,
                val_files,
                val_labels,
                val_weights,
                feature_files,
                all_labels,
                all_weights,
                weights_per_class,
                file_path_model,
                filename_log)
def finetune_model_validation(filename_train_validation_set,
                              filename_labels_train_validation_set,
                              filename_sample_weights,
                              filter_density, dropout, input_shape,
                              file_path_model, filename_log,
                              model_name, path_model,
                              channel=1):
    """
    Load a saved model from ``path_model`` and train on top of it through
    ``model_train_validation``.

    When ``model_name`` equals ``'retrained'`` the loaded model itself is
    trained and ``multi_inputs=False`` is passed on.  For any other
    ``model_name`` a new network is obtained from ``model_switcher`` (which
    receives the loaded model as ``model_pretrained`` and uses a
    ``'sigmoid'`` dense activation) and ``multi_inputs=True`` is passed on.

    Training data descriptors come from ``load_data_jingju``; only the
    train/validation splits and the class weights are used here.  A batch
    size of 256 and a patience of 15 are always used.

    Nothing is returned.
    """
    # Fixed settings handed to model_train_validation.
    n_batch = 256
    patience_value = 15

    loaded = load_data_jingju(filename_labels_train_validation_set,
                              filename_sample_weights)

    # Ten items are expected; the three underscore-prefixed ones are not
    # needed by this function.
    (train_files, train_labels, train_weights,
     val_files, val_labels, val_weights,
     _feature_files, _all_labels, _all_weights, weights_per_class) = loaded

    # The saved model is loaded in both branches.
    pretrained = load_model(filepath=path_model)

    reuse_loaded_model = model_name == 'retrained'
    if reuse_loaded_model:
        network = pretrained
    else:
        network = model_switcher(model_name=model_name,
                                 filter_density=filter_density,
                                 dropout=dropout,
                                 input_shape=input_shape,
                                 channel=channel,
                                 model_pretrained=pretrained,
                                 activation_dense='sigmoid')

    model_train_validation(network, n_batch, patience_value, input_shape,
                           filename_train_validation_set,
                           train_files, train_labels, train_weights,
                           val_files, val_labels, val_weights,
                           weights_per_class,
                           file_path_model, filename_log,
                           channel,
                           multi_inputs=not reuse_loaded_model)
def train_model_validation(filename_train_validation_set,
                           filename_labels_train_validation_set,
                           filename_sample_weights,
                           filter_density, dropout, input_shape,
                           file_path_model, filename_log,
                           model_name='baseline',
                           activation_dense='sigmoid',
                           channel=1):
    """
    Build a network with ``model_switcher`` and train it through
    ``model_train_validation``.

    ``model_name``, ``filter_density``, ``dropout``, ``input_shape``,
    ``channel`` and ``activation_dense`` are passed to ``model_switcher``
    to create the network.  Training data descriptors come from
    ``load_data_jingju``; only the train/validation splits and the class
    weights are used here.  A batch size of 256 and a patience of 15 are
    always used.

    Nothing is returned.
    """
    # Fixed settings handed to model_train_validation.
    n_batch = 256
    patience_value = 15

    loaded = load_data_jingju(filename_labels_train_validation_set,
                              filename_sample_weights)

    # Ten items are expected; the three underscore-prefixed ones are not
    # needed by this function.
    (train_files, train_labels, train_weights,
     val_files, val_labels, val_weights,
     _feature_files, _all_labels, _all_weights, weights_per_class) = loaded

    network = model_switcher(model_name=model_name,
                             filter_density=filter_density,
                             dropout=dropout,
                             input_shape=input_shape,
                             channel=channel,
                             activation_dense=activation_dense)

    model_train_validation(network,
                           n_batch,
                           patience_value,
                           input_shape,
                           filename_train_validation_set,
                           train_files,
                           train_labels,
                           train_weights,
                           val_files,
                           val_labels,
                           val_weights,
                           weights_per_class,
                           file_path_model,
                           filename_log,
                           channel)