def main():
    """Run training or evaluation using the settings hard-coded below.

    The run mode is selected by three local settings:

    * ``testFLAG`` true: evaluate the model at
      ``restore_checkpoint_at_epoch`` on the validation and test sets, with
      logs and TensorBoard files under
      ``./train-<indicator>/eval-<timestamp>/``.
    * ``restore_checkpoint_at_epoch == 0`` or ``transferFLAG`` true: start a
      new run under ``./train-<timestamp>/`` and store timestamped copies of
      the launching script (one in the run directory, one in the current
      directory).
    * otherwise: resume the run under ``./train-<indicator>/`` and append to
      its existing log file.

    ``<indicator>`` is the part of the launching script's file name that
    precedes the first ``-``; ``<timestamp>`` is the UTC start time formatted
    as ``%Y%m%d%H%M%S``.

    The function takes no arguments and returns nothing; its effects are the
    directories, log file and TensorBoard files it creates and the text it
    prints.
    """
    import datetime
    import os

    ## Sets
    #train_set_path='../Sets/demoset.h5'
    #validation_set_path='../Sets/demoset.h5'
    #test_set_path='../Sets/demoset.h5'

    train_set_path = '../Sets/trainset.h5'
    validation_set_path = '../Sets/validationset1.h5'
    test_set_path = '../Sets/testset.h5'

    ## Database location
    DataBasePath = '../Database/lines'

    transferFLAG = False
    testFLAG = False

    batch_size = 16
    num_epochs = 250
    learning_rate = 0.0003
    num_epochs_before_validation = 10

    restore_checkpoint_at_epoch = 0

    now = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")

    # File name of the launching script without any directory part. Using the
    # raw sys.argv[0] in a destination path breaks as soon as the script is
    # started with a path (e.g. "./script.py" or "/abs/path/script.py").
    script_name = os.path.basename(sys.argv[0])

    ## Log file
    if testFLAG:
        # Evaluation of an existing run: new eval directory, existing models.
        indicator = script_name.split('-')[0]
        files_path = './train-{}/eval-{}/'.format(indicator, now)
        log_path = './train-{}/eval-{}/log/'.format(indicator, now)
        log_file_path = './train-{}/eval-{}/log/log.txt'.format(indicator, now)
        models_path = './train-{}/models/'.format(indicator)
        TensorBoard_dir = './train-{}/eval-{}/TensorBoard_files/'.format(
            indicator, now)
        tf.gfile.MakeDirs(files_path)
        tf.gfile.MakeDirs(log_path)
        tf.gfile.MakeDirs(TensorBoard_dir)
        copyfile(sys.argv[0], files_path + '{}-'.format(now) + script_name)
        log_file_indicator = initialize_log(log_file_path, mode='w')
    elif restore_checkpoint_at_epoch == 0 or transferFLAG:
        # Fresh run (possibly transfer learning): everything is created anew.
        files_path = './train-{}/'.format(now)
        log_path = './train-{}/log'.format(now)
        log_file_path = './train-{}/log/log.txt'.format(now)
        models_path = './train-{}/models/'.format(now)
        TensorBoard_dir = './train-{}/TensorBoard_files/'.format(now)
        tf.gfile.MakeDirs(files_path)
        tf.gfile.MakeDirs(log_path)
        tf.gfile.MakeDirs(models_path)
        tf.gfile.MakeDirs(TensorBoard_dir)
        copyfile(sys.argv[0], files_path + '{}-'.format(now) + script_name)
        # Second copy next to the original; its "<timestamp>-" prefix is what
        # the other two branches read back as the run indicator.
        copyfile(sys.argv[0], './{}-'.format(now) + script_name)
        log_file_indicator = initialize_log(log_file_path, mode='w')
    else:
        # Resuming an interrupted run: reuse its directories, append to log.
        indicator = script_name.split('-')[0]
        log_path = './train-{}/log'.format(indicator)
        log_file_path = './train-{}/log/log.txt'.format(indicator)
        models_path = './train-{}/models/'.format(indicator)
        TensorBoard_dir = './train-{}/TensorBoard_files/'.format(indicator)
        log_file_indicator = initialize_log(log_file_path, mode='a')
        log_file_indicator.write(
            ('#' * 100 + '\n') * 5 +
            '\n\nRecovering after break or pause in epoch ' +
            str(restore_checkpoint_at_epoch) + '\n\n' + ('#' * 100 + '\n') * 5)

    if transferFLAG:
        model_for_transfer_path = './model_for_transfer'

    # One "step" is a group of num_epochs_before_validation epochs followed
    # by a validation pass.
    num_steps = ceil(num_epochs / num_epochs_before_validation)

    validSet, valid_imageHeight, valid_imageWidth, valid_labels = load_dataset(
        validation_set_path, DataBasePath, log_file_indicator)
    if not testFLAG:
        trainSet, train_imageHeight, train_imageWidth, train_labels = load_dataset(
            train_set_path, DataBasePath, log_file_indicator)
        # No test set is loaded in training mode, so the training values are
        # passed in both the first and the third position.
        imageHeight, labels = check_valid_and_test_sets(
            train_imageHeight, valid_imageHeight, train_imageHeight,
            train_labels, valid_labels, train_labels, log_file_indicator)
        train_writer = tf.summary.FileWriter(TensorBoard_dir + 'train_task')
        valid_vs_writer = tf.summary.FileWriter(TensorBoard_dir +
                                                'valid_task_validset')
        valid_ts_writer = tf.summary.FileWriter(TensorBoard_dir +
                                                'valid_task_trainset')
    else:
        testSet, test_imageHeight, test_imageWidth, test_labels = load_dataset(
            test_set_path, DataBasePath, log_file_indicator)
        imageHeight, labels = check_valid_and_test_sets(
            test_imageHeight, valid_imageHeight, test_imageHeight, test_labels,
            valid_labels, test_labels, log_file_indicator)
        test_writer = tf.summary.FileWriter(TensorBoard_dir + 'test_validset')
        valid_writer = tf.summary.FileWriter(TensorBoard_dir + 'valid_testset')
    log_file_indicator.flush()

    # The number of classes is the amount of labels plus 1 for blank
    num_classes = len(labels) + 1

    train_start = time.time()
    network_train = Network()
    if transferFLAG:
        # NOTE(review): train_imageWidth and train_writer are only bound when
        # testFLAG is False, so transferFLAG and testFLAG must not both be
        # True.
        epoch = restore_checkpoint_at_epoch
        transfer(epoch, network_train, imageHeight, train_imageWidth,
                 num_classes, log_file_indicator, model_for_transfer_path,
                 models_path, train_writer)

    if not testFLAG:

        # When resuming, start from the step that contains the restored epoch.
        first_step = ceil(restore_checkpoint_at_epoch /
                          num_epochs_before_validation)
        for step in range(first_step, num_steps):

            train(step, network_train, num_epochs_before_validation,
                  batch_size, learning_rate, trainSet, imageHeight,
                  train_imageWidth, num_classes, log_file_indicator,
                  models_path, train_writer, transferFLAG)

            # Zero-based index of the last epoch covered by this step.
            epoch = (step + 1) * num_epochs_before_validation - 1

            validation(epoch, network_train, batch_size, 'validation',
                       validSet, imageHeight, valid_imageWidth, labels,
                       num_classes, log_file_indicator, models_path,
                       valid_vs_writer)
            validation(epoch, network_train, batch_size, 'train', trainSet,
                       imageHeight, train_imageWidth, labels, num_classes,
                       log_file_indicator, models_path, valid_ts_writer)

        train_end = time.time()
        train_duration = train_end - train_start
        print('Training completed in: ' +
              seconds_to_days_hours_min_sec(train_duration))
        log_file_indicator.write(
            '\nTraining completed in: ' +
            seconds_to_days_hours_min_sec(train_duration, day_flag=True) +
            '\n')

    else:

        epoch = restore_checkpoint_at_epoch
        text = '\nEvaluating model at epoch {}...\n'.format(epoch)
        print(text)
        log_file_indicator.write(text)
        validation(epoch, network_train, batch_size, 'validation', validSet,
                   imageHeight, valid_imageWidth, labels, num_classes,
                   log_file_indicator, models_path, valid_writer)
        validation(epoch, network_train, batch_size, 'test', testSet,
                   imageHeight, test_imageWidth, labels, num_classes,
                   log_file_indicator, models_path, test_writer)

        test_end = time.time()
        test_duration = test_end - train_start
        print('Evaluation completed in: ' +
              seconds_to_days_hours_min_sec(test_duration))
        log_file_indicator.write(
            '\nEvaluation completed in: ' +
            seconds_to_days_hours_min_sec(test_duration, day_flag=True) + '\n')

    # Print the shape and size of every trainable variable, then the total
    # number of trainable parameters.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        print(shape)
        print(len(shape))
        variable_parameters = 1
        for dim in shape:
            print(dim)
            variable_parameters *= dim.value
        print(variable_parameters)
        total_parameters += variable_parameters
    print(total_parameters)

    if not testFLAG:
        train_writer.close()
        valid_ts_writer.close()
        valid_vs_writer.close()
    else:
        test_writer.close()
        valid_writer.close()

    log_file_indicator.flush()
    log_file_indicator.close()
Exemple #2
0
from prefect import Parameter, utilities
from prefect.core.flow import Flow
from tasks import extract, pre_screen, transfer, load, post_load_orphan
import logging

# Module-level logger obtained through Prefect's logging utilities.
log = utilities.logging.get_logger()

# Claims ETL pipeline: pre-screen -> extract -> transfer -> load, followed by
# the post-load orphan step.
with Flow("{{Client}} Claims ETL") as flow:
    configuration_id = Parameter("configuration_id")

    # log.info(f"Running claims ETL with configuration {configuration_id}.")
    member_list = pre_screen(configuration_id)
    extracted_data = extract(member_list)

    # The transfer step consumes the first two items of the extract output.
    first_extract = extracted_data[0]
    second_extract = extracted_data[1]
    patient_roster = transfer(first_extract, second_extract)

    load_result = load(patient_roster=patient_roster)

    # The orphan step receives the pre-screened member list and names the
    # load step as an explicit upstream task.
    post_load_orphan(
        member_number_list=member_list,
        upstream_tasks=[load_result],
    )

flow.register()