Example 1
def cyclegan(self):
    # hide the prebuilt-GAN selection window, reset the default TF graph,
    # and load the CycleGAN configuration into the controller
    self.tl_prebuilt_gan.withdraw()
    tf.reset_default_graph()
    cwd = os.getcwd()
    config_file = os.path.join(cwd, "prebuilt_configs/cyclegan.json")
    configs = load_config(config_file)
    self.controller.set_configs(configs)
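All of these prebuilt loaders delegate to load_config, imported from src.utils.general_utils (see Example 8). A minimal sketch of what it likely does, assuming the prebuilt configuration files are plain JSON; the actual implementation may differ:

import json

def load_config(config_file):
    # read a JSON configuration file into a nested dict of settings
    with open(config_file, 'r') as f:
        return json.load(f)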
Example 2
def densenet201(self):
    self.tl_prebuilt_cnn.withdraw()
    tf.reset_default_graph()
    cwd = os.getcwd()
    config_file = os.path.join(cwd, "prebuilt_configs/densenet201.json")
    configs = load_config(config_file)
    self.controller.set_configs(configs)
Example 3
def mobilenet_ssd(self):
    self.tl_prebuilt_bbd.withdraw()
    tf.reset_default_graph()
    cwd = os.getcwd()
    config_file = os.path.join(cwd, "prebuilt_configs/mobilenet_ssd.json")
    configs = load_config(config_file)
    self.controller.set_configs(configs)
Example 4
def inceptionresnetv2(self):
    self.tl_prebuilt_cnn.withdraw()
    tf.reset_default_graph()
    cwd = os.getcwd()
    config_file = os.path.join(cwd, "prebuilt_configs/inceptionresnetv2.json")
    configs = load_config(config_file)
    self.controller.set_configs(configs)
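Examples 1 through 4 differ only in the window they hide and the JSON file they load, so a single helper could back all of them. A hypothetical refactor (the method name _load_prebuilt is an assumption, not part of the original class):

def _load_prebuilt(self, toplevel, name):
    # shared logic: hide the selection window, reset the graph,
    # and hand the named prebuilt configuration to the controller
    toplevel.withdraw()
    tf.reset_default_graph()
    config_file = os.path.join(os.getcwd(), "prebuilt_configs/{}.json".format(name))
    self.controller.set_configs(load_config(config_file))

def densenet201(self):
    self._load_prebuilt(self.tl_prebuilt_cnn, "densenet201")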
Example 5
def main(FLAGS):
    # set GPU devices to use
    # os.environ['CUDA_VISIBLE_DEVICES'] = '6,7'

    # define the experiments
    encoders = FLAGS.encoders.split(',')
    losses = FLAGS.losses.split(',')
    alpha = FLAGS.loss_param1.split(',')
    experiments = [encoders, losses, alpha]
    print(experiments)
    for experiment in itertools.product(*experiments):
        print(experiment)

    for experiment in itertools.product(*experiments):
        # switch to activate training session
        do_train = True

        # load the base configurations
        if experiment[0] == 'UNet':
            configs = load_config(os.path.join(FLAGS.base_configs_dir, 'unet2d.json'))
        elif experiment[0] == 'UNet3D':
            configs = load_config(os.path.join(FLAGS.base_configs_dir, 'unet3d.json'))
        elif experiment[0] == 'VGG16' or experiment[0] == 'VGG19':
            configs = load_config(os.path.join(FLAGS.base_configs_dir, 'vgg16_unet.json'))
        else:
            configs = load_config(os.path.join(FLAGS.base_configs_dir, 'xception_unet.json'))

        # apply some augmentation
        configs['augmentation']['apply_augmentation_switch'] = 'True'
        configs['augmentation']['width_shift'] = '0.25'
        configs['augmentation']['height_shift'] = '0.25'
        configs['augmentation']['rotation_range'] = '10'
        configs['augmentation']['zoom_range'] = '0.2'
        configs['augmentation']['shear_range'] = '0.15'
        configs['augmentation']['horizontal_flip'] = 'True'
        configs['augmentation']['vertical_flip'] = 'False'
        configs['augmentation']['rounds'] = '2'
        #configs['augmentation']['brightness_range'] = '(0.75, 1.0)'

        # perform some preprocessing
        configs['preprocessing']['categorical_switch'] = 'True'
        configs['preprocessing']['minimum_image_intensity'] = '0.0'
        configs['preprocessing']['maximum_image_intensity'] = '2048.0'
        configs['preprocessing']['categories'] = '{}'.format(FLAGS.classes)
        configs['preprocessing']['normalization_type'] = '{}'.format(FLAGS.normalization)

        # set the training configurations
        if FLAGS.num_gpus > 1:
            configs['training_configurations']['hardware'] = 'multi-gpu'
        else:
            configs['training_configurations']['hardware'] = 'gpu'
        configs['training_configurations']['number_of_gpus'] = '{}'.format(FLAGS.num_gpus)
        configs['training_configurations']['batch_size'] = '{}'.format(FLAGS.batch_size)
        configs['training_configurations']['epochs'] = '{}'.format(FLAGS.epochs)
        configs['training_configurations']['validation_split'] = '0.0'

        # set the learning rate
        configs['learning_rate_schedule']['learning_rate'] = '{}'.format(FLAGS.learning_rate)

        # turn on savers
        configs['save_configurations']['save_checkpoints_switch'] = 'True'
        configs['save_configurations']['save_csv_switch'] = 'True'

        # define data paths
        configs['paths']['train_X'] = FLAGS.train_X_path
        configs['paths']['train_y'] = FLAGS.train_y_path
        configs['paths']['validation_X'] = FLAGS.valid_X_path
        configs['paths']['validation_y'] = FLAGS.valid_y_path

        ckpt_path_append = 'encoder_{}_loss_{}_alpha_{}_beta_{}_ckpt.h5'.format(experiment[0],
                                                                                experiment[1],
                                                                                experiment[2],
                                                                                1. - literal_eval(experiment[2]))
        csv_path_append = 'encoder_{}_loss_{}_alpha_{}_beta_{}_history.csv'.format(experiment[0],
                                                                                   experiment[1],
                                                                                   experiment[2],
                                                                                   1. - literal_eval(experiment[2]))

        configs['save_configurations']['save_csv_path'] = os.path.join(FLAGS.save_csv_path, csv_path_append)
        configs['save_configurations']['save_checkpoints_path'] = os.path.join(FLAGS.save_ckpt_path, ckpt_path_append)

        layers = configs['layers']['serial_layer_list']
        input_layer = layers[0]  # renamed from 'input' to avoid shadowing the builtin
        input_parts = input_layer.split(':')
        if experiment[0] == 'UNet3D':
            input_parts[-1] = '({}, {}, {}, {})'.format(FLAGS.height, FLAGS.width, FLAGS.slices, FLAGS.channels)
        else:
            input_parts[-1] = '({}, {}, {})'.format(FLAGS.height, FLAGS.width, FLAGS.channels)
        input_layer = ':'.join(input_parts)
        configs['config_file']['input_shape'] = '({}, {}, {})'.format(FLAGS.height, FLAGS.width, FLAGS.channels)

        if experiment[0] == 'UNet' or experiment[0] == 'UNet3D':
            layers[0] = input_layer
            last_conv = layers[-3]
            last_conv_parts = last_conv.split(':')
            last_conv_parts[1] = '{}'.format(FLAGS.classes)
            last_conv = ':'.join(last_conv_parts)
            layers[-3] = last_conv
            if not FLAGS.use_skip_connections:
                print('-------------------------------------------')
                print('removing skip connections')
                print('-------------------------------------------')
                while 'Outer skip target:concatenate' in layers:
                    layers.remove('Outer skip target:concatenate')
            configs['layers']['serial_layer_list'] = layers
        else:
            encoder = layers[1]
            decoder = layers[2:]
            last_conv = decoder[-3]
            last_conv_parts = last_conv.split(':')
            last_conv_parts[1] = '{}'.format(FLAGS.classes)
            last_conv = ':'.join(last_conv_parts)
            decoder[-3] = last_conv
            encoder_parts = encoder.split(':')
            encoder_parts[0] = experiment[0]
            encoder_parts[3] = '({}, {}, {})'.format(FLAGS.height, FLAGS.width, FLAGS.channels)
            if FLAGS.use_skip_connections:
                encoder_parts[-2] = 'True'
            else:
                print('-------------------------------------------')
                print('removing skip connections')
                print('-------------------------------------------')
                encoder_parts[-2] = 'False'
                while 'Outer skip target:concatenate' in decoder:
                    decoder.remove('Outer skip target:concatenate')

            if FLAGS.use_imagenet_weights:
                print('-------------------------------------------')
                print('using ImageNet weights')
                print('-------------------------------------------')
                encoder_parts[1] = 'False'
                encoder_parts[2] = 'imagenet'
                print(encoder_parts)
                #configs['preprocessing']['repeat_X_switch'] = 'True'
                #configs['preprocessing']['repeat_X_quantity'] = '3'
            else:
                encoder_parts[1] = 'False'
                encoder_parts[2] = 'none'
                configs['preprocessing']['repeat_X_switch'] = 'False'

            # join after the weight settings so they actually reach the layer spec
            encoder = ':'.join(encoder_parts)

            # inject some layers to connect the encoder to the decoder;
            # this experiment needs few injectors, but others may need more,
            # so they are broken up by encoder type
            if experiment[0] in ('InceptionResNetV2', 'InceptionV3'):
                decoder.insert(0, 'Zero padding 2D:((1, 1), (1, 1))')
            elif experiment[0] in ('VGG16', 'VGG19',
                                   'DenseNet121', 'DenseNet169', 'DenseNet201',
                                   'MobileNet', 'MobileNetV2',
                                   'ResNet50', 'ResNet101', 'ResNet152',
                                   'ResNet50V2', 'ResNet101V2', 'ResNet152V2',
                                   'ResNeXt50', 'ResNeXt101',
                                   'Xception'):
                pass
            else:
                print('Invalid encoder type:', experiment[0])
                do_train = False

            layers = [input_layer, encoder]
            if FLAGS.dropout:
                layers.append('Dropout:0.5')
            layers.extend(decoder)
            configs['layers']['serial_layer_list'] = layers

        # run xentropy/jaccard/focal only once per encoder: they ignore the
        # alpha sweep, so train them only at the first alpha value (0.3)
        if experiment[1] in ('sparse_categorical_crossentropy',
                             'categorical_crossentropy',
                             'weighted_categorical_crossentropy',
                             'jaccard',
                             'focal'):
            if experiment[1] == 'focal':
                configs['loss_function']['parameter1'] = '0.75'
                configs['loss_function']['parameter2'] = '2.0'
            if experiment[1] == 'jaccard':
                configs['loss_function']['parameter1'] = '100.0'
            if experiment[2] == '0.3':
                configs['loss_function']['loss'] = experiment[1]
            else:
                do_train = False
        elif experiment[1] == 'tversky':
            configs['loss_function']['loss'] = experiment[1]
            configs['loss_function']['parameter1'] = experiment[2]
            configs['loss_function']['parameter2'] = str(1. - literal_eval(experiment[2]))
        else:
            do_train = False

        if do_train:
            configs_lvl1, errors_lvl1, warnings_lvl1 = level_one_error_checking(configs)

            if any(warnings_lvl1):
                with open('errors.txt', 'a') as f:
                    for warning in warnings_lvl1:
                        f.write("%s\n" % warning)
                print('Level 1 warnings encountered.')
                print("The following level 1 warnings were identified and corrected based on engine defaults:")
                for warning in warnings_lvl1:
                    print(warning)

            if any(errors_lvl1):
                print('Level 1 errors encountered.')
                print("Please fix the level 1 errors below before continuing:")
                for error in errors_lvl1:
                    print(error)
            else:
                configs_lvl2, errors_lvl2, warnings_lvl2 = level_two_error_checking(configs_lvl1)

                if any(warnings_lvl2):
                    print('Level 2 warnings encountered.')
                    print("The following level 2 warnings were identified and corrected based on engine defaults:")
                    for warning in warnings_lvl2:
                        print(warning)

                if any(errors_lvl2):
                    print('Level 2 errors encountered.')
                    print("Please fix the level 2 errors below before continuing:")
                    for error in errors_lvl2:
                        print(error)
                else:
                    engine = Dlae(configs)
                    engine.run()
                    if any(engine.errors):
                        print('Level 3 errors encountered.')
                        print("Please fix the level 3 errors below before continuing:")
                        for error in engine.errors:
                            print(error)
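This driver assumes a FLAGS namespace populated elsewhere. A hypothetical argparse wiring that covers the attributes main() accesses; the flag names mirror those accesses, but the defaults are purely illustrative and not from the original source:

import argparse

def build_parser():
    # hypothetical flag definitions inferred from the FLAGS attributes used in main()
    p = argparse.ArgumentParser()
    for name, default in [('encoders', 'UNet,Xception'),
                          ('losses', 'categorical_crossentropy,tversky'),
                          ('loss_param1', '0.3,0.5,0.7'),
                          ('base_configs_dir', 'base_configs'),
                          ('normalization', 'global_max'),
                          ('train_X_path', 'data/train_X.h5'),
                          ('train_y_path', 'data/train_y.h5'),
                          ('valid_X_path', 'data/valid_X.h5'),
                          ('valid_y_path', 'data/valid_y.h5'),
                          ('save_csv_path', 'results'),
                          ('save_ckpt_path', 'checkpoints')]:
        p.add_argument('--' + name, type=str, default=default)
    for name, default in [('classes', 6), ('num_gpus', 1), ('batch_size', 4),
                          ('epochs', 100), ('height', 512), ('width', 512),
                          ('slices', 48), ('channels', 1)]:
        p.add_argument('--' + name, type=int, default=default)
    p.add_argument('--learning_rate', type=float, default=1e-4)
    for name in ('use_skip_connections', 'use_imagenet_weights', 'dropout'):
        p.add_argument('--' + name, action='store_true')
    return p

if __name__ == '__main__':
    main(build_parser().parse_args())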
Example 6
def main(FLAGS):
    # set GPU device to use
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'

    # define the experiments
    encoders = FLAGS.encoders.split(',')
    losses = FLAGS.losses.split(',')
    alpha = FLAGS.loss_param1.split(',')
    experiments = [encoders, losses, alpha]
    print(experiments)

    # get data files and sort them
    image_files = os.listdir(FLAGS.test_X_dir)
    anno_files = os.listdir(FLAGS.test_y_dir)
    image_files.sort()
    anno_files.sort()

    for experiment in itertools.product(*experiments):
        # switch to activate evaluation session
        do_eval = True

        # load the base configurations
        if experiment[0] == 'UNet':
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'unet2d.json'))
        elif experiment[0] == 'UNet3D':
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'unet3d.json'))
        elif experiment[0] == 'VGG16' or experiment[0] == 'VGG19':
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'vgg16_unet.json'))
        else:
            configs = load_config(
                os.path.join(FLAGS.base_configs_dir, 'xception_unet.json'))

        # set the path to the model (checkpoints are fine for this)
        ckpt_name = 'encoder_{}_loss_{}_alpha_{}_beta_{}_ckpt.h5'.format(
            experiment[0], experiment[1], experiment[2],
            1. - literal_eval(experiment[2]))
        configs['paths']['load_model'] = os.path.join(FLAGS.ckpt_dir,
                                                      ckpt_name)

        # switch to inference
        configs['config_file']['type_signal'] = 'Inference'

        # perform some preprocessing
        configs['preprocessing']['categorical_switch'] = 'True'
        configs['preprocessing']['minimum_image_intensity'] = '0.0'
        configs['preprocessing']['maximum_image_intensity'] = '2048.0'
        configs['preprocessing']['normalization_type'] = '{}'.format(
            FLAGS.normalization)

        # set some other configurations
        configs['training_configurations']['batch_size'] = '{}'.format(
            FLAGS.batch_size)
        configs['config_file']['input_shape'] = '({}, {}, {})'.format(
            FLAGS.height, FLAGS.width, FLAGS.channels)

        # run xentropy/jaccard/focal only once per encoder: they ignore the
        # alpha sweep, so evaluate them only at the first alpha value (0.3)
        if experiment[1] in ('sparse_categorical_crossentropy',
                             'categorical_crossentropy',
                             'jaccard',
                             'focal'):
            if experiment[1] == 'focal':
                configs['loss_function']['parameter1'] = '0.75'
                configs['loss_function']['parameter2'] = '2.0'
            if experiment[1] == 'jaccard':
                configs['loss_function']['parameter1'] = '100.0'
            if experiment[2] == '0.3':
                configs['loss_function']['loss'] = experiment[1]
            else:
                do_eval = False
        elif experiment[1] == 'tversky':
            configs['loss_function']['loss'] = experiment[1]
            configs['loss_function']['parameter1'] = experiment[2]
            configs['loss_function']['parameter2'] = str(
                1. - literal_eval(experiment[2]))
        else:
            do_eval = False

        # create a location to store evaluation metrics
        metrics = np.zeros((len(image_files), FLAGS.classes, 8))
        overall_accuracy = np.zeros((len(image_files), ))

        # create a file writer to store the metrics
        excel_name = '{}_{}_{}_{}_metrics.xlsx'.format(
            experiment[0], experiment[1], experiment[2],
            1. - literal_eval(experiment[2]))
        writer = pd.ExcelWriter(excel_name)

        for i in range(len(image_files)):
            K.clear_session()
            # define path to the test data
            configs['paths']['test_X'] = os.path.join(FLAGS.test_X_dir,
                                                      image_files[i])

            if do_eval:
                configs_lvl1, errors_lvl1, warnings_lvl1 = level_one_error_checking(
                    configs)

                if any(warnings_lvl1):
                    with open('errors.txt', 'a') as f:
                        for warning in warnings_lvl1:
                            f.write("%s\n" % warning)
                    print('Level 1 warnings encountered.')
                    print("The following level 1 warnings were identified and corrected based on engine defaults:")
                    for warning in warnings_lvl1:
                        print(warning)

                if any(errors_lvl1):
                    print('Level 1 errors encountered.')
                    print(
                        "Please fix the level 1 errors below before continuing:"
                    )
                    for error in errors_lvl1:
                        print(error)
                else:
                    configs_lvl2, errors_lvl2, warnings_lvl2 = level_two_error_checking(
                        configs_lvl1)

                    if any(warnings_lvl2):
                        print('Level 2 warnings encountered.')
                        print(
                            "The following level 2 warnings were identified and corrected based on engine defaults:"
                        )
                        for warning in warnings_lvl2:
                            print(warning)

                    if any(errors_lvl2):
                        print('Level 2 errors encountered.')
                        print(
                            "Please fix the level 2 errors below before continuing:"
                        )
                        for error in errors_lvl2:
                            print(error)
                    else:
                        engine = Dlae(configs)
                        engine.run()
                        if any(engine.errors):
                            print('Level 3 errors encountered.')
                            print(
                                "Please fix the level 3 errors below before continuing:"
                            )
                            for error in engine.errors:
                                print(error)

                # check if the images and annotations are the correct files
                print(image_files[i], anno_files[i])

                pred_file = glob(
                    os.path.join(FLAGS.predictions_temp_dir, '*.h5'))[0]
                pt_name = image_files[i].split('.')[0]
                new_name_raw = pt_name + '_{}_{}_{}_{}_raw.h5'.format(
                    experiment[0], experiment[1], experiment[2],
                    1. - literal_eval(experiment[2]))
                new_file_raw = os.path.join(FLAGS.predictions_final_dir,
                                            new_name_raw)
                os.rename(pred_file, new_file_raw)

                ref = read_hdf5_multientry(
                    os.path.join(FLAGS.test_y_dir, anno_files[i]))
                ref = np.squeeze(np.asarray(ref))

                preds = read_hdf5(new_file_raw)
                if experiment[0] == 'UNet3D':
                    ref = np.transpose(ref, (1, 2, 0))

                    # stitch the image back together first
                    sw = SlidingWindow(ref, [96, 96, 40], [128, 128, 48])
                    preds = sw.stitch_patches(preds, sw.window_corner_coords,
                                              [128, 128, 48], sw.img_shape,
                                              FLAGS.classes)
                    preds = np.argmax(preds, axis=-1)
                else:
                    preds = np.argmax(preds, axis=-1)

                overall_accuracy[i] = skm.accuracy_score(
                    ref.flatten(), preds.flatten())
                for j in range(FLAGS.classes):
                    organ_pred = (preds == j).astype(np.int64)
                    organ_ref = (ref == j).astype(np.int64)
                    if np.sum(organ_pred) == 0 or np.sum(organ_ref) == 0:
                        metrics[i, j, 0] = 0.
                        metrics[i, j, 1] = 0.
                        metrics[i, j, 2] = 1.
                        metrics[i, j, 3] = 0.
                        metrics[i, j, 4] = 0.
                        metrics[i, j, 5] = 0.
                        metrics[i, j, 6] = np.inf
                        metrics[i, j, 7] = np.inf
                    else:
                        metrics[i, j, 0] = jaccard_index(organ_ref, organ_pred)
                        metrics[i, j, 1] = dice_similarity_coefficient(
                            organ_ref, organ_pred)
                        metrics[i, j, 2] = relative_volume_difference(
                            organ_ref, organ_pred)
                        metrics[i, j, 3] = precision(organ_ref, organ_pred)
                        metrics[i, j, 4] = recall(organ_ref, organ_pred)
                        metrics[i, j, 5] = matthews_correlation_coefficient(
                            organ_ref, organ_pred)
                        metrics[i, j, 6] = mpm.hd95(organ_pred, organ_ref)
                        metrics[i, j, 7] = mpm.assd(organ_pred, organ_ref)
                print(overall_accuracy[i])
                print(metrics[i])


        if do_eval:
            for k in range(metrics.shape[-1]):
                data = pd.DataFrame(
                    metrics[:, :, k],
                    columns=['bg', 'pros', 'eus', 'sv', 'rect', 'blad'])
                data.to_excel(writer, sheet_name=str(k))
            acc = pd.DataFrame(overall_accuracy, columns=['acc'])
            acc.to_excel(writer, sheet_name='acc')
            writer.save()
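The per-class metrics above come from helper functions imported elsewhere in the project. A minimal sketch of two of them for binary masks, assuming the standard definitions (the originals may handle edge cases differently; note the caller already guards against empty masks before dividing):

import numpy as np

def jaccard_index(ref, pred):
    # intersection over union of two binary masks
    intersection = np.sum(ref * pred)
    union = np.sum(ref) + np.sum(pred) - intersection
    return intersection / union

def dice_similarity_coefficient(ref, pred):
    # 2 * |A ∩ B| / (|A| + |B|)
    intersection = np.sum(ref * pred)
    return 2. * intersection / (np.sum(ref) + np.sum(pred))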
Example 7
def main(FLAGS):
    configs = load_config(FLAGS.base_configs)
    configs['training_configurations']['batch_size'] = '512'

    c_train_X = 'train_classification_windows'
    c_train_y = 'train_classification_labels'
    c_valid_X = 'valid_classification_windows'
    c_valid_y = 'valid_classification_labels'
    l_train_X = 'train_localization_windows'
    l_train_y = 'train_localization_labels'
    l_valid_X = 'valid_localization_windows'
    l_valid_y = 'valid_localization_labels'

    paddings = FLAGS.padding.split(',')
    window_sizes = FLAGS.window_size.split(',')
    strides = FLAGS.stride.split(',')
    bg2sat_ratios = FLAGS.bg2sat_ratio.split(',')
    experiments = [window_sizes, strides, paddings, bg2sat_ratios]

    for experiment in itertools.product(*experiments):
        if int(experiment[0]) - 2 * int(experiment[2]) >= FLAGS.minimum_center\
                and int(experiment[1]) < int(experiment[0]) - 2 * int(experiment[2]):
            h5_path_append = '_seedNet2satNet_windowsize_{}_stride_{}_padding_{}_ratio_{}_trainfraction_{}.h5'.format(experiment[0], experiment[1], experiment[2], experiment[3], FLAGS.train_fraction)
            csv_path_append = '_seedNet2satNet_windowsize_{}_stride_{}_padding_{}_ratio_{}_trainfraction_{}.csv'.format(experiment[0], experiment[1], experiment[2], experiment[3], FLAGS.train_fraction)

            # build the dataset-creation command; note the leading spaces so the
            # flags do not run together, and the per-experiment values so the
            # generated file names match h5_path_append above
            create_command = 'python C:/Users/jsanders/Desktop/seedNet2satNet/create_annotated_windows.py'\
                           + ' --satnet_data_dir={}'.format(FLAGS.satnet_data_dir)\
                           + ' --save_data_dir={}'.format(FLAGS.save_data_dir)\
                           + ' --train_file_names={}'.format(FLAGS.train_file_names)\
                           + ' --valid_file_names={}'.format(FLAGS.valid_file_names)\
                           + ' --window_size={}'.format(experiment[0])\
                           + ' --stride={}'.format(experiment[1])\
                           + ' --padding={}'.format(experiment[2])\
                           + ' --bg2sat_ratio={}'.format(experiment[3])\
                           + ' --format={}'.format(FLAGS.format)

            os.system(create_command)

            configs['paths']['train_X'] = os.path.join(FLAGS.satnet_data_dir, c_train_X + h5_path_append)
            configs['paths']['train_y'] = os.path.join(FLAGS.satnet_data_dir, c_train_y + h5_path_append)
            configs['paths']['validation_X'] = os.path.join(FLAGS.satnet_data_dir, c_valid_X + h5_path_append)
            configs['paths']['validation_y'] = os.path.join(FLAGS.satnet_data_dir, c_valid_y + h5_path_append)
            configs['monitors']['accuracy_switch'] = 'True'
            configs['monitors']['mse_switch'] = 'False'
            configs['save_configurations']['save_model_path'] = os.path.join(FLAGS.save_model_path, 'classification_model' + h5_path_append)
            configs['save_configurations']['save_csv_path'] = os.path.join(FLAGS.save_csv_path, 'classification_csv' + csv_path_append)
            configs['save_configurations']['save_checkpoints_path'] = os.path.join(FLAGS.save_ckpt_path, 'classification_ckpt' + h5_path_append)

            configs_lvl1, errors_lvl1, warnings_lvl1 = level_one_error_checking(configs)

            if any(warnings_lvl1):
                with open('errors.txt', 'a') as f:
                    for warning in warnings_lvl1:
                        f.write("%s\n" % warning)
                print('Level 1 warnings encountered.')
                print("The following level 1 warnings were identified and corrected based on engine defaults:")
                for warning in warnings_lvl1:
                    print(warning)

            if any(errors_lvl1):
                print('Level 1 errors encountered.')
                print("Please fix the level 1 errors below before continuing:")
                for error in errors_lvl1:
                    print(error)
            else:
                configs_lvl2, errors_lvl2, warnings_lvl2 = level_two_error_checking(configs_lvl1)

                if any(warnings_lvl2):
                    print('Level 2 warnings encountered.')
                    print("The following level 2 warnings were identified and corrected based on engine defaults:")
                    for warning in warnings_lvl2:
                        print(warning)

                if any(errors_lvl2):
                    print('Level 2 errors encountered.')
                    print("Please fix the level 2 errors below before continuing:")
                    for error in errors_lvl2:
                        print(error)
                else:
                    engine = Dlae(configs)
                    engine.run()
                    if any(engine.errors):
                        print('Level 3 errors encountered.')
                        print("Please fix the level 3 errors below before continuing:")
                        for error in engine.errors:
                            print(error)

            os.remove(configs['paths']['train_X'])
            os.remove(configs['paths']['train_y'])
            os.remove(configs['paths']['validation_X'])
            os.remove(configs['paths']['validation_y'])

            configs['paths']['train_X'] = os.path.join(FLAGS.satnet_data_dir, l_train_X + h5_path_append)
            configs['paths']['train_y'] = os.path.join(FLAGS.satnet_data_dir, l_train_y + h5_path_append)
            configs['paths']['validation_X'] = os.path.join(FLAGS.satnet_data_dir, l_valid_X + h5_path_append)
            configs['paths']['validation_y'] = os.path.join(FLAGS.satnet_data_dir, l_valid_y + h5_path_append)
            configs['preprocessing']['categorical_switch'] = 'False'
            configs['preprocessing']['weight_loss_switch'] = 'False'
            configs['loss_function']['loss'] = 'mean_squared_error'
            configs['monitors']['accuracy_switch'] = 'False'
            configs['monitors']['mse_switch'] = 'True'
            configs['save_configurations']['save_model_path'] = os.path.join(FLAGS.save_model_path, 'localization_model' + h5_path_append)
            configs['save_configurations']['save_csv_path'] = os.path.join(FLAGS.save_csv_path, 'localization_csv' + csv_path_append)
            configs['save_configurations']['save_checkpoints_path'] = os.path.join(FLAGS.save_ckpt_path, 'localization_ckpt' + h5_path_append)
            layers = configs['layers']['serial_layer_list']
            layers.pop()
            configs['layers']['serial_layer_list'] = layers

            configs_lvl1, errors_lvl1, warnings_lvl1 = level_one_error_checking(configs)

            if any(warnings_lvl1):
                with open('errors.txt', 'a') as f:
                    for warning in warnings_lvl1:
                        f.write("%s\n" % warning)
                print('Level 1 warnings encountered.')
                print("The following level 1 warnings were identified and corrected based on engine defaults:")
                for warning in warnings_lvl1:
                    print(warning)

            if any(errors_lvl1):
                print('Level 1 errors encountered.')
                print("Please fix the level 1 errors below before continuing:")
                for error in errors_lvl1:
                    print(error)
            else:
                configs_lvl2, errors_lvl2, warnings_lvl2 = level_two_error_checking(configs_lvl1)

                if any(warnings_lvl2):
                    print('Level 2 warnings encountered.')
                    print("The following level 2 warnings were identified and corrected based on engine defaults:")
                    for warning in warnings_lvl2:
                        print(warning)

                if any(errors_lvl2):
                    print('Level 2 errors encountered.')
                    print("Please fix the level 2 errors below before continuing:")
                    for error in errors_lvl2:
                        print(error)
                else:
                    engine = Dlae(configs)
                    engine.run()
                    if any(engine.errors):
                        print('Level 3 errors encountered.')
                        print("Please fix the level 3 errors below before continuing:")
                        for error in engine.errors:
                            print(error)

            os.remove(configs['paths']['train_X'])
            os.remove(configs['paths']['train_y'])
            os.remove(configs['paths']['validation_X'])
            os.remove(configs['paths']['validation_y'])
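Building shell commands by string concatenation is fragile; a single missing space silently merges two flags, which is exactly the bug fixed above. A sketch of the same invocation as an argument list with subprocess, which sidesteps spacing and quoting issues entirely (same hypothetical per-experiment values as above):

import subprocess

script = 'C:/Users/jsanders/Desktop/seedNet2satNet/create_annotated_windows.py'
cmd = ['python', script,
       '--satnet_data_dir={}'.format(FLAGS.satnet_data_dir),
       '--save_data_dir={}'.format(FLAGS.save_data_dir),
       '--train_file_names={}'.format(FLAGS.train_file_names),
       '--valid_file_names={}'.format(FLAGS.valid_file_names),
       '--window_size={}'.format(experiment[0]),
       '--stride={}'.format(experiment[1]),
       '--padding={}'.format(experiment[2]),
       '--bg2sat_ratio={}'.format(experiment[3]),
       '--format={}'.format(FLAGS.format)]
subprocess.run(cmd, check=True)  # raises CalledProcessError on a nonzero exit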
Example 8
# See the License for the specific language governing permissions and
# limitations under the License.
"""dlae.py"""

import sys
from src.gui.constructor import DlaeGui
from src.engine.constructor import Dlae
from src.utils.engine_utils import level_one_error_checking
from src.utils.engine_utils import level_two_error_checking
from src.utils.general_utils import load_config
import tkinter as tk

if __name__ == '__main__':
    if len(sys.argv) > 1:
        config_file = sys.argv[1]
        configs = load_config(config_file)
        configs_lvl1, errors_lvl1, warnings_lvl1 = level_one_error_checking(
            configs)

        if any(warnings_lvl1):
            print('Level 1 warnings encountered.')
            print(
                "The following level 1 warnings were identified and corrected based on engine defaults:"
            )
            for warning in warnings_lvl1:
                print(warning)

        if any(errors_lvl1):
            print('Level 1 errors encountered.')
            print("Please fix the level 1 errors below before continuing:")
            for error in errors_lvl1: