Example #1
def _write_metadata_one_cnn(model_object, argument_dict):
    """Writes metadata for one CNN to file.

    :param model_object: Untrained CNN (instance of `keras.models.Model` or
        `keras.models.Sequential`).
    :param argument_dict: See doc for `_train_one_cnn`.
    :return: metadata_dict: See doc for `cnn.write_model_metadata`.
    :return: training_option_dict: Same.
    """

    import numpy
    from gewittergefahr.gg_utils import time_conversion
    from gewittergefahr.gg_utils import file_system_utils
    from gewittergefahr.deep_learning import cnn
    from gewittergefahr.deep_learning import input_examples
    from gewittergefahr.deep_learning import \
        training_validation_io as trainval_io
    from gewittergefahr.scripts import deep_learning_helper as dl_helper
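    # RADAR_FIELDS_KEY, LAYER_OPERATIONS_KEY, MIN_HEIGHTS_KEY, MAX_HEIGHTS_KEY,
    # TIME_FORMAT, FIRST_BATCH_NUMBER, LAST_BATCH_NUMBER, and
    # SOUNDING_HEIGHTS_M_AGL are assumed to be module-level constants defined
    # elsewhere in this script.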

    # Read input args.
    sounding_field_names = argument_dict[dl_helper.SOUNDING_FIELDS_ARG_NAME]
    radar_field_name_by_channel = argument_dict[RADAR_FIELDS_KEY]
    layer_op_name_by_channel = argument_dict[LAYER_OPERATIONS_KEY]
    min_height_by_channel_m_agl = argument_dict[MIN_HEIGHTS_KEY]
    max_height_by_channel_m_agl = argument_dict[MAX_HEIGHTS_KEY]

    normalization_type_string = argument_dict[
        dl_helper.NORMALIZATION_TYPE_ARG_NAME]
    normalization_file_name = argument_dict[
        dl_helper.NORMALIZATION_FILE_ARG_NAME]
    min_normalized_value = argument_dict[dl_helper.MIN_NORM_VALUE_ARG_NAME]
    max_normalized_value = argument_dict[dl_helper.MAX_NORM_VALUE_ARG_NAME]

    target_name = argument_dict[dl_helper.TARGET_NAME_ARG_NAME]
    downsampling_classes = numpy.array(
        argument_dict[dl_helper.DOWNSAMPLING_CLASSES_ARG_NAME], dtype=int)
    downsampling_fractions = numpy.array(
        argument_dict[dl_helper.DOWNSAMPLING_FRACTIONS_ARG_NAME], dtype=float)

    monitor_string = argument_dict[dl_helper.MONITOR_ARG_NAME]
    weight_loss_function = bool(argument_dict[dl_helper.WEIGHT_LOSS_ARG_NAME])

    x_translations_pixels = numpy.array(
        argument_dict[dl_helper.X_TRANSLATIONS_ARG_NAME], dtype=int)
    y_translations_pixels = numpy.array(
        argument_dict[dl_helper.Y_TRANSLATIONS_ARG_NAME], dtype=int)
    ccw_rotation_angles_deg = numpy.array(
        argument_dict[dl_helper.ROTATION_ANGLES_ARG_NAME], dtype=float)
    noise_standard_deviation = argument_dict[dl_helper.NOISE_STDEV_ARG_NAME]
    num_noisings = argument_dict[dl_helper.NUM_NOISINGS_ARG_NAME]
    flip_in_x = bool(argument_dict[dl_helper.FLIP_X_ARG_NAME])
    flip_in_y = bool(argument_dict[dl_helper.FLIP_Y_ARG_NAME])

    top_training_dir_name = argument_dict[dl_helper.TRAINING_DIR_ARG_NAME]
    first_training_time_string = argument_dict[
        dl_helper.FIRST_TRAINING_TIME_ARG_NAME]
    last_training_time_string = argument_dict[
        dl_helper.LAST_TRAINING_TIME_ARG_NAME]
    num_examples_per_train_batch = argument_dict[
        dl_helper.NUM_EX_PER_TRAIN_ARG_NAME]

    top_validation_dir_name = argument_dict[dl_helper.VALIDATION_DIR_ARG_NAME]
    first_validation_time_string = argument_dict[
        dl_helper.FIRST_VALIDATION_TIME_ARG_NAME]
    last_validation_time_string = argument_dict[
        dl_helper.LAST_VALIDATION_TIME_ARG_NAME]
    num_examples_per_validn_batch = argument_dict[
        dl_helper.NUM_EX_PER_VALIDN_ARG_NAME]

    num_epochs = argument_dict[dl_helper.NUM_EPOCHS_ARG_NAME]
    num_training_batches_per_epoch = argument_dict[
        dl_helper.NUM_TRAINING_BATCHES_ARG_NAME]
    num_validation_batches_per_epoch = argument_dict[
        dl_helper.NUM_VALIDATION_BATCHES_ARG_NAME]
    output_dir_name = argument_dict[dl_helper.OUTPUT_DIR_ARG_NAME]

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if sounding_field_names[0] in ['', 'None']:
        sounding_field_names = None

    num_channels = len(radar_field_name_by_channel)
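    # Build one layer-operation dict per channel, telling the example generator
    # which radar field each channel comes from and which layer operation
    # (e.g., a min, mean, or max over the given height range) produces it.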
    layer_operation_dicts = [{}] * num_channels

    for k in range(num_channels):
        layer_operation_dicts[k] = {
            input_examples.RADAR_FIELD_KEY: radar_field_name_by_channel[k],
            input_examples.OPERATION_NAME_KEY: layer_op_name_by_channel[k],
            input_examples.MIN_HEIGHT_KEY: min_height_by_channel_m_agl[k],
            input_examples.MAX_HEIGHT_KEY: max_height_by_channel_m_agl[k]
        }

    if len(downsampling_classes) > 1:
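        # Map each target class to its sampling fraction.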
        downsampling_dict = dict(
            list(zip(downsampling_classes, downsampling_fractions)))
    else:
        downsampling_dict = None

    translate_flag = (len(x_translations_pixels) > 1
                      or x_translations_pixels[0] != 0
                      or y_translations_pixels[0] != 0)

    if not translate_flag:
        x_translations_pixels = None
        y_translations_pixels = None

    if len(ccw_rotation_angles_deg) == 1 and ccw_rotation_angles_deg[0] == 0:
        ccw_rotation_angles_deg = None

    if num_noisings <= 0:
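        # A non-positive num_noisings disables noise-based augmentation.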
        num_noisings = 0
        noise_standard_deviation = None

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Write metadata.
    metadata_dict = {
        cnn.NUM_EPOCHS_KEY: num_epochs,
        cnn.NUM_TRAINING_BATCHES_KEY: num_training_batches_per_epoch,
        cnn.NUM_VALIDATION_BATCHES_KEY: num_validation_batches_per_epoch,
        cnn.MONITOR_STRING_KEY: monitor_string,
        cnn.WEIGHT_LOSS_FUNCTION_KEY: weight_loss_function,
        cnn.CONV_2D3D_KEY: False,
        cnn.VALIDATION_FILES_KEY: validation_file_names,
        cnn.FIRST_VALIDN_TIME_KEY: first_validation_time_unix_sec,
        cnn.LAST_VALIDN_TIME_KEY: last_validation_time_unix_sec,
        cnn.LAYER_OPERATIONS_KEY: layer_operation_dicts,
        cnn.NUM_EX_PER_VALIDN_BATCH_KEY: num_examples_per_validn_batch
    }

    input_tensor = model_object.input
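    # If the model has multiple input tensors (e.g., radar plus soundings), use
    # the first one to infer the spatial grid dimensions.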
    if isinstance(input_tensor, list):
        input_tensor = input_tensor[0]

    num_grid_rows = input_tensor.get_shape().as_list()[1]
    num_grid_columns = input_tensor.get_shape().as_list()[2]

    training_option_dict = {
        trainval_io.EXAMPLE_FILES_KEY: training_file_names,
        trainval_io.TARGET_NAME_KEY: target_name,
        trainval_io.FIRST_STORM_TIME_KEY: first_training_time_unix_sec,
        trainval_io.LAST_STORM_TIME_KEY: last_training_time_unix_sec,
        trainval_io.NUM_EXAMPLES_PER_BATCH_KEY: num_examples_per_train_batch,
        trainval_io.SOUNDING_FIELDS_KEY: sounding_field_names,
        trainval_io.SOUNDING_HEIGHTS_KEY: SOUNDING_HEIGHTS_M_AGL,
        trainval_io.NUM_ROWS_KEY: num_grid_rows,
        trainval_io.NUM_COLUMNS_KEY: num_grid_columns,
        trainval_io.NORMALIZATION_TYPE_KEY: normalization_type_string,
        trainval_io.NORMALIZATION_FILE_KEY: normalization_file_name,
        trainval_io.MIN_NORMALIZED_VALUE_KEY: min_normalized_value,
        trainval_io.MAX_NORMALIZED_VALUE_KEY: max_normalized_value,
        trainval_io.BINARIZE_TARGET_KEY: False,
        trainval_io.SAMPLING_FRACTIONS_KEY: downsampling_dict,
        trainval_io.LOOP_ONCE_KEY: False,
        trainval_io.X_TRANSLATIONS_KEY: x_translations_pixels,
        trainval_io.Y_TRANSLATIONS_KEY: y_translations_pixels,
        trainval_io.ROTATION_ANGLES_KEY: ccw_rotation_angles_deg,
        trainval_io.NOISE_STDEV_KEY: noise_standard_deviation,
        trainval_io.NUM_NOISINGS_KEY: num_noisings,
        trainval_io.FLIP_X_KEY: flip_in_x,
        trainval_io.FLIP_Y_KEY: flip_in_y
    }

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)
    metafile_name = '{0:s}/model_metadata.p'.format(output_dir_name)

    print('Writing metadata to: "{0:s}"...'.format(metafile_name))
    cnn.write_model_metadata(pickle_file_name=metafile_name,
                             metadata_dict=metadata_dict,
                             training_option_dict=training_option_dict)

    return metadata_dict, training_option_dict
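# Usage sketch (hypothetical, not part of the original script): the helper
# expects an untrained Keras model plus the dictionary of parsed command-line
# arguments, assuming the parser's flag names match the *_ARG_NAME keys above.
#
#     INPUT_ARG_OBJECT = INPUT_ARG_PARSER.parse_args()   # script's arg parser
#     untrained_model_object = cnn.read_model('untrained_architecture.h5')
#     metadata_dict, training_option_dict = _write_metadata_one_cnn(
#         model_object=untrained_model_object,
#         argument_dict=vars(INPUT_ARG_OBJECT))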
def _run(input_model_file_name, sounding_field_names,
         normalization_type_string, normalization_param_file_name,
         min_normalized_value, max_normalized_value, target_name,
         downsampling_classes, downsampling_fractions, monitor_string,
         weight_loss_function, x_translations_pixels, y_translations_pixels,
         ccw_rotation_angles_deg, noise_standard_deviation, num_noisings,
         flip_in_x, flip_in_y, top_training_dir_name,
         first_training_time_string, last_training_time_string,
         num_examples_per_train_batch, top_validation_dir_name,
         first_validation_time_string, last_validation_time_string,
         num_examples_per_validn_batch, num_epochs,
         num_training_batches_per_epoch, num_validation_batches_per_epoch,
         output_dir_name):
    """Trains CNN with 2-D and 3-D MYRORSS images.

    This is effectively the main method.

    :param input_model_file_name: See documentation at top of file.
    :param sounding_field_names: Same.
    :param normalization_type_string: Same.
    :param normalization_param_file_name: Same.
    :param min_normalized_value: Same.
    :param max_normalized_value: Same.
    :param target_name: Same.
    :param downsampling_classes: Same.
    :param downsampling_fractions: Same.
    :param monitor_string: Same.
    :param weight_loss_function: Same.
    :param x_translations_pixels: Same.
    :param y_translations_pixels: Same.
    :param ccw_rotation_angles_deg: Same.
    :param noise_standard_deviation: Same.
    :param num_noisings: Same.
    :param flip_in_x: Same.
    :param flip_in_y: Same.
    :param top_training_dir_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param num_examples_per_train_batch: Same.
    :param top_validation_dir_name: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_validn_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_dir_name: Same.
    """

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    # argument_file_name = '{0:s}/input_args.p'.format(output_dir_name)
    # print('Writing input args to: "{0:s}"...'.format(argument_file_name))
    #
    # argument_file_handle = open(argument_file_name, 'wb')
    # pickle.dump(INPUT_ARG_OBJECT.__dict__, argument_file_handle)
    # argument_file_handle.close()
    #
    # return

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if sounding_field_names[0] in ['', 'None']:
        sounding_field_names = None

    if len(downsampling_classes) > 1:
        downsampling_dict = dict(
            list(zip(downsampling_classes, downsampling_fractions)))
    else:
        downsampling_dict = None

    if (len(x_translations_pixels) == 1
            and x_translations_pixels + y_translations_pixels == 0):
        x_translations_pixels = None
        y_translations_pixels = None

    if len(ccw_rotation_angles_deg) == 1 and ccw_rotation_angles_deg[0] == 0:
        ccw_rotation_angles_deg = None

    if num_noisings <= 0:
        num_noisings = 0
        noise_standard_deviation = None

    # Set output locations.
    output_model_file_name = '{0:s}/model.h5'.format(output_dir_name)
    history_file_name = '{0:s}/model_history.csv'.format(output_dir_name)
    tensorboard_dir_name = '{0:s}/tensorboard'.format(output_dir_name)
    model_metafile_name = '{0:s}/model_metadata.p'.format(output_dir_name)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Read architecture.
    print(
        'Reading architecture from: "{0:s}"...'.format(input_model_file_name))
    model_object = cnn.read_model(input_model_file_name)
    # model_object = keras.models.clone_model(model_object)

    # TODO(thunderhoser): This is a HACK.
    model_object.compile(loss=keras.losses.binary_crossentropy,
                         optimizer=keras.optimizers.Adam(),
                         metrics=cnn_setup.DEFAULT_METRIC_FUNCTION_LIST)

    print(SEPARATOR_STRING)
    model_object.summary()
    print(SEPARATOR_STRING)

    # Write metadata.
    metadata_dict = {
        cnn.NUM_EPOCHS_KEY: num_epochs,
        cnn.NUM_TRAINING_BATCHES_KEY: num_training_batches_per_epoch,
        cnn.NUM_VALIDATION_BATCHES_KEY: num_validation_batches_per_epoch,
        cnn.MONITOR_STRING_KEY: monitor_string,
        cnn.WEIGHT_LOSS_FUNCTION_KEY: weight_loss_function,
        cnn.CONV_2D3D_KEY: True,
        cnn.VALIDATION_FILES_KEY: validation_file_names,
        cnn.FIRST_VALIDN_TIME_KEY: first_validation_time_unix_sec,
        cnn.LAST_VALIDN_TIME_KEY: last_validation_time_unix_sec,
        cnn.NUM_EX_PER_VALIDN_BATCH_KEY: num_examples_per_validn_batch
    }

    if isinstance(model_object.input, list):
        list_of_input_tensors = model_object.input
    else:
        list_of_input_tensors = [model_object.input]

    upsample_refl = len(list_of_input_tensors) == 2
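    # Two input tensors are taken to mean that reflectivity has been upsampled
    # to the azimuthal-shear grid, so the dimensions read from the model are
    # double the native reflectivity grid and are halved below.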
    num_grid_rows = list_of_input_tensors[0].get_shape().as_list()[1]
    num_grid_columns = list_of_input_tensors[0].get_shape().as_list()[2]

    if upsample_refl:
        num_grid_rows = int(numpy.round(num_grid_rows / 2))
        num_grid_columns = int(numpy.round(num_grid_columns / 2))

    training_option_dict = {
        trainval_io.EXAMPLE_FILES_KEY: training_file_names,
        trainval_io.TARGET_NAME_KEY: target_name,
        trainval_io.FIRST_STORM_TIME_KEY: first_training_time_unix_sec,
        trainval_io.LAST_STORM_TIME_KEY: last_training_time_unix_sec,
        trainval_io.NUM_EXAMPLES_PER_BATCH_KEY: num_examples_per_train_batch,
        trainval_io.RADAR_FIELDS_KEY:
        input_examples.AZIMUTHAL_SHEAR_FIELD_NAMES,
        trainval_io.RADAR_HEIGHTS_KEY: REFLECTIVITY_HEIGHTS_M_AGL,
        trainval_io.SOUNDING_FIELDS_KEY: sounding_field_names,
        trainval_io.SOUNDING_HEIGHTS_KEY: SOUNDING_HEIGHTS_M_AGL,
        trainval_io.NUM_ROWS_KEY: num_grid_rows,
        trainval_io.NUM_COLUMNS_KEY: num_grid_columns,
        trainval_io.NORMALIZATION_TYPE_KEY: normalization_type_string,
        trainval_io.NORMALIZATION_FILE_KEY: normalization_param_file_name,
        trainval_io.MIN_NORMALIZED_VALUE_KEY: min_normalized_value,
        trainval_io.MAX_NORMALIZED_VALUE_KEY: max_normalized_value,
        trainval_io.BINARIZE_TARGET_KEY: False,
        trainval_io.SAMPLING_FRACTIONS_KEY: downsampling_dict,
        trainval_io.LOOP_ONCE_KEY: False,
        trainval_io.X_TRANSLATIONS_KEY: x_translations_pixels,
        trainval_io.Y_TRANSLATIONS_KEY: y_translations_pixels,
        trainval_io.ROTATION_ANGLES_KEY: ccw_rotation_angles_deg,
        trainval_io.NOISE_STDEV_KEY: noise_standard_deviation,
        trainval_io.NUM_NOISINGS_KEY: num_noisings,
        trainval_io.FLIP_X_KEY: flip_in_x,
        trainval_io.FLIP_Y_KEY: flip_in_y,
        trainval_io.UPSAMPLE_REFLECTIVITY_KEY: upsample_refl
    }

    print('Writing metadata to: "{0:s}"...'.format(model_metafile_name))
    cnn.write_model_metadata(pickle_file_name=model_metafile_name,
                             metadata_dict=metadata_dict,
                             training_option_dict=training_option_dict)

    cnn.train_cnn_2d3d_myrorss(
        model_object=model_object,
        model_file_name=output_model_file_name,
        history_file_name=history_file_name,
        tensorboard_dir_name=tensorboard_dir_name,
        num_epochs=num_epochs,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_option_dict=training_option_dict,
        monitor_string=monitor_string,
        weight_loss_function=weight_loss_function,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_file_names=validation_file_names,
        first_validn_time_unix_sec=first_validation_time_unix_sec,
        last_validn_time_unix_sec=last_validation_time_unix_sec,
        num_examples_per_validn_batch=num_examples_per_validn_batch)
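# Entry-point sketch (hypothetical; the original script's argparse setup is not
# shown in this excerpt).  Flag names below are illustrative stand-ins for the
# ~30 arguments that _run expects.
if __name__ == '__main__':
    import argparse

    PARSER = argparse.ArgumentParser()
    PARSER.add_argument('--input_model_file_name', type=str, required=True)
    PARSER.add_argument('--output_dir_name', type=str, required=True)
    # ... the remaining flags would be added the same way ...
    ARGS = PARSER.parse_args()

    # _run(input_model_file_name=ARGS.input_model_file_name, ...,
    #      output_dir_name=ARGS.output_dir_name)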
Example #3
def _run(input_model_file_name, sounding_field_names, normalization_type_string,
         normalization_param_file_name, min_normalized_value,
         max_normalized_value, target_name, downsampling_classes,
         downsampling_fractions, monitor_string, weight_loss_function,
         top_training_dir_name, first_training_time_string,
         last_training_time_string, num_examples_per_train_batch,
         top_validation_dir_name, first_validation_time_string,
         last_validation_time_string, num_examples_per_validn_batch, num_epochs,
         num_training_batches_per_epoch, num_validation_batches_per_epoch,
         output_dir_name):
    """Trains CNN with soundings only.

    This is effectively the main method.

    :param input_model_file_name: See documentation at top of file.
    :param sounding_field_names: Same.
    :param normalization_type_string: Same.
    :param normalization_param_file_name: Same.
    :param min_normalized_value: Same.
    :param max_normalized_value: Same.
    :param target_name: Same.
    :param downsampling_classes: Same.
    :param downsampling_fractions: Same.
    :param monitor_string: Same.
    :param weight_loss_function: Same.
    :param top_training_dir_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param num_examples_per_train_batch: Same.
    :param top_validation_dir_name: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_validn_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_dir_name: Same.
    """

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if len(downsampling_classes) > 1:
        downsampling_dict = dict(list(zip(
            downsampling_classes, downsampling_fractions
        )))
    else:
        downsampling_dict = None

    # Set output locations.
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    output_model_file_name = '{0:s}/model.h5'.format(output_dir_name)
    history_file_name = '{0:s}/model_history.csv'.format(output_dir_name)
    tensorboard_dir_name = '{0:s}/tensorboard'.format(output_dir_name)
    model_metafile_name = '{0:s}/model_metadata.p'.format(output_dir_name)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER, raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name, shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER, raise_error_if_any_missing=False)

    # Read architecture.
    print('Reading architecture from: "{0:s}"...'.format(input_model_file_name))
    model_object = cnn.read_model(input_model_file_name)
    # model_object = keras.models.clone_model(model_object)

    # TODO(thunderhoser): This is a HACK.
    model_object.compile(
        loss=keras.losses.binary_crossentropy,
        optimizer=keras.optimizers.Adam(),
        metrics=cnn_setup.DEFAULT_METRIC_FUNCTION_LIST)

    print(SEPARATOR_STRING)
    model_object.summary()
    print(SEPARATOR_STRING)

    # Write metadata.
    metadata_dict = {
        cnn.NUM_EPOCHS_KEY: num_epochs,
        cnn.NUM_TRAINING_BATCHES_KEY: num_training_batches_per_epoch,
        cnn.NUM_VALIDATION_BATCHES_KEY: num_validation_batches_per_epoch,
        cnn.MONITOR_STRING_KEY: monitor_string,
        cnn.WEIGHT_LOSS_FUNCTION_KEY: weight_loss_function,
        cnn.CONV_2D3D_KEY: False,
        cnn.VALIDATION_FILES_KEY: validation_file_names,
        cnn.FIRST_VALIDN_TIME_KEY: first_validation_time_unix_sec,
        cnn.LAST_VALIDN_TIME_KEY: last_validation_time_unix_sec,
        cnn.NUM_EX_PER_VALIDN_BATCH_KEY: num_examples_per_validn_batch
    }

    training_option_dict = {
        trainval_io.EXAMPLE_FILES_KEY: training_file_names,
        trainval_io.NUM_EXAMPLES_PER_BATCH_KEY: num_examples_per_train_batch,
        trainval_io.FIRST_STORM_TIME_KEY: first_training_time_unix_sec,
        trainval_io.LAST_STORM_TIME_KEY: last_training_time_unix_sec,
        trainval_io.SOUNDING_FIELDS_KEY: sounding_field_names,
        trainval_io.SOUNDING_HEIGHTS_KEY: SOUNDING_HEIGHTS_M_AGL,
        trainval_io.NORMALIZATION_TYPE_KEY: normalization_type_string,
        trainval_io.NORMALIZATION_FILE_KEY: normalization_param_file_name,
        trainval_io.MIN_NORMALIZED_VALUE_KEY: min_normalized_value,
        trainval_io.MAX_NORMALIZED_VALUE_KEY: max_normalized_value,
        trainval_io.TARGET_NAME_KEY: target_name,
        trainval_io.BINARIZE_TARGET_KEY: False,
        trainval_io.SAMPLING_FRACTIONS_KEY: downsampling_dict,
        trainval_io.LOOP_ONCE_KEY: False
    }

    print('Writing metadata to: "{0:s}"...'.format(model_metafile_name))
    cnn.write_model_metadata(
        pickle_file_name=model_metafile_name, metadata_dict=metadata_dict,
        training_option_dict=training_option_dict)

    cnn.train_cnn_with_soundings(
        model_object=model_object, model_file_name=output_model_file_name,
        history_file_name=history_file_name,
        tensorboard_dir_name=tensorboard_dir_name, num_epochs=num_epochs,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_option_dict=training_option_dict,
        monitor_string=monitor_string,
        weight_loss_function=weight_loss_function,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_file_names=validation_file_names,
        first_validn_time_unix_sec=first_validation_time_unix_sec,
        last_validn_time_unix_sec=last_validation_time_unix_sec,
        num_examples_per_validn_batch=num_examples_per_validn_batch)
def _run(input_model_file_name, radar_field_names, sounding_field_names,
         normalization_type_string, normalization_param_file_name,
         min_normalized_value, max_normalized_value, downsampling_keys,
         downsampling_fractions, monitor_string, weight_loss_function,
         refl_masking_threshold_dbz, x_translations_pixels,
         y_translations_pixels, ccw_rotation_angles_deg,
         noise_standard_deviation, num_noisings, flip_in_x, flip_in_y,
         top_training_dir_name, first_training_time_string,
         last_training_time_string, top_validation_dir_name,
         first_validation_time_string, last_validation_time_string,
         num_examples_per_batch, num_epochs, num_training_batches_per_epoch,
         num_validation_batches_per_epoch, output_dir_name):
    """Trains CNN with native (3-D) GridRad images.

    This is effectively the main method.

    :param input_model_file_name: See documentation at top of file.
    :param radar_field_names: Same.
    :param sounding_field_names: Same.
    :param normalization_type_string: Same.
    :param normalization_param_file_name: Same.
    :param min_normalized_value: Same.
    :param max_normalized_value: Same.
    :param downsampling_keys: Same.
    :param downsampling_fractions: Same.
    :param monitor_string: Same.
    :param weight_loss_function: Same.
    :param refl_masking_threshold_dbz: Same.
    :param x_translations_pixels: Same.
    :param y_translations_pixels: Same.
    :param ccw_rotation_angles_deg: Same.
    :param noise_standard_deviation: Same.
    :param num_noisings: Same.
    :param flip_in_x: Same.
    :param flip_in_y: Same.
    :param top_training_dir_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param top_validation_dir_name: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_dir_name: Same.
    """

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        first_validation_time_string, TIME_FORMAT)
    last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
        last_validation_time_string, TIME_FORMAT)

    if sounding_field_names[0] in ['', 'None']:
        sounding_field_names = None

    if len(downsampling_keys) > 1:
        class_to_sampling_fraction_dict = dict(
            zip(downsampling_keys, downsampling_fractions))
    else:
        class_to_sampling_fraction_dict = None

    if (len(x_translations_pixels) == 1
            and x_translations_pixels + y_translations_pixels == 0):
        x_translations_pixels = None
        y_translations_pixels = None

    if len(ccw_rotation_angles_deg) == 1 and ccw_rotation_angles_deg[0] == 0:
        ccw_rotation_angles_deg = None

    if num_noisings <= 0:
        num_noisings = 0
        noise_standard_deviation = None

    # Set output locations.
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    output_model_file_name = '{0:s}/model.h5'.format(output_dir_name)
    history_file_name = '{0:s}/model_history.csv'.format(output_dir_name)
    tensorboard_dir_name = '{0:s}/tensorboard'.format(output_dir_name)
    model_metafile_name = '{0:s}/model_metadata.p'.format(output_dir_name)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    validation_file_names = input_examples.find_many_example_files(
        top_directory_name=top_validation_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    # Read architecture.
    print(
        'Reading architecture from: "{0:s}"...'.format(input_model_file_name))
    model_object = cnn.read_model(input_model_file_name)
    model_object = keras.models.clone_model(model_object)

    # TODO(thunderhoser): This is a HACK.
    model_object.compile(loss=keras.losses.binary_crossentropy,
                         optimizer=keras.optimizers.Adam(),
                         metrics=cnn_setup.DEFAULT_METRIC_FUNCTION_LIST)

    print(SEPARATOR_STRING)
    model_object.summary()
    print(SEPARATOR_STRING)

    # Write metadata.
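    # The target name is not a script argument in this example; it is read from
    # the metadata of the first training file.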
    this_example_dict = input_examples.read_example_file(
        netcdf_file_name=training_file_names[0], metadata_only=True)
    target_name = this_example_dict[input_examples.TARGET_NAME_KEY]

    metadata_dict = {
        cnn.TARGET_NAME_KEY: target_name,
        cnn.NUM_EPOCHS_KEY: num_epochs,
        cnn.NUM_TRAINING_BATCHES_KEY: num_training_batches_per_epoch,
        cnn.NUM_VALIDATION_BATCHES_KEY: num_validation_batches_per_epoch,
        cnn.MONITOR_STRING_KEY: monitor_string,
        cnn.WEIGHT_LOSS_FUNCTION_KEY: weight_loss_function,
        cnn.USE_2D3D_CONVOLUTION_KEY: False,
        cnn.VALIDATION_FILES_KEY: validation_file_names,
        cnn.FIRST_VALIDN_TIME_KEY: first_validation_time_unix_sec,
        cnn.LAST_VALIDN_TIME_KEY: last_validation_time_unix_sec
    }

    input_tensor = model_object.input
    if isinstance(input_tensor, list):
        input_tensor = input_tensor[0]

    num_grid_rows = input_tensor.get_shape().as_list()[1]
    num_grid_columns = input_tensor.get_shape().as_list()[2]

    training_option_dict = {
        trainval_io.EXAMPLE_FILES_KEY: training_file_names,
        trainval_io.FIRST_STORM_TIME_KEY: first_training_time_unix_sec,
        trainval_io.LAST_STORM_TIME_KEY: last_training_time_unix_sec,
        trainval_io.NUM_EXAMPLES_PER_BATCH_KEY: num_examples_per_batch,
        trainval_io.RADAR_FIELDS_KEY: radar_field_names,
        trainval_io.RADAR_HEIGHTS_KEY: RADAR_HEIGHTS_M_AGL,
        trainval_io.SOUNDING_FIELDS_KEY: sounding_field_names,
        trainval_io.SOUNDING_HEIGHTS_KEY: SOUNDING_HEIGHTS_M_AGL,
        trainval_io.NUM_ROWS_KEY: num_grid_rows,
        trainval_io.NUM_COLUMNS_KEY: num_grid_columns,
        trainval_io.NORMALIZATION_TYPE_KEY: normalization_type_string,
        trainval_io.NORMALIZATION_FILE_KEY: normalization_param_file_name,
        trainval_io.MIN_NORMALIZED_VALUE_KEY: min_normalized_value,
        trainval_io.MAX_NORMALIZED_VALUE_KEY: max_normalized_value,
        trainval_io.BINARIZE_TARGET_KEY: False,
        trainval_io.SAMPLING_FRACTIONS_KEY: class_to_sampling_fraction_dict,
        trainval_io.LOOP_ONCE_KEY: False,
        trainval_io.REFLECTIVITY_MASK_KEY: refl_masking_threshold_dbz,
        trainval_io.X_TRANSLATIONS_KEY: x_translations_pixels,
        trainval_io.Y_TRANSLATIONS_KEY: y_translations_pixels,
        trainval_io.ROTATION_ANGLES_KEY: ccw_rotation_angles_deg,
        trainval_io.NOISE_STDEV_KEY: noise_standard_deviation,
        trainval_io.NUM_NOISINGS_KEY: num_noisings,
        trainval_io.FLIP_X_KEY: flip_in_x,
        trainval_io.FLIP_Y_KEY: flip_in_y
    }

    print('Writing metadata to: "{0:s}"...'.format(model_metafile_name))
    cnn.write_model_metadata(pickle_file_name=model_metafile_name,
                             metadata_dict=metadata_dict,
                             training_option_dict=training_option_dict)

    cnn.train_cnn_2d_or_3d(
        model_object=model_object,
        model_file_name=output_model_file_name,
        history_file_name=history_file_name,
        tensorboard_dir_name=tensorboard_dir_name,
        num_epochs=num_epochs,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_option_dict=training_option_dict,
        monitor_string=monitor_string,
        weight_loss_function=weight_loss_function,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_file_names=validation_file_names,
        first_validn_time_unix_sec=first_validation_time_unix_sec,
        last_validn_time_unix_sec=last_validation_time_unix_sec)
Example #5
def _run(input_model_file_name, radar_field_name_by_channel,
         layer_op_name_by_channel, min_height_by_channel_m_agl,
         max_height_by_channel_m_agl, sounding_field_names,
         normalization_type_string, normalization_param_file_name,
         min_normalized_value, max_normalized_value, target_name,
         shuffle_target, downsampling_classes, downsampling_fractions,
         monitor_string, weight_loss_function, x_translations_pixels,
         y_translations_pixels, ccw_rotation_angles_deg,
         noise_standard_deviation, num_noisings, flip_in_x, flip_in_y,
         top_training_dir_name, first_training_time_string,
         last_training_time_string, num_examples_per_train_batch,
         top_validation_dir_name, first_validation_time_string,
         last_validation_time_string, num_examples_per_validn_batch,
         num_epochs, num_training_batches_per_epoch,
         num_validation_batches_per_epoch, output_dir_name):
    """Trains CNN with 2-D GridRad images.

    This is effectively the main method.

    :param input_model_file_name: See documentation at top of file.
    :param radar_field_name_by_channel: Same.
    :param layer_op_name_by_channel: Same.
    :param min_height_by_channel_m_agl: Same.
    :param max_height_by_channel_m_agl: Same.
    :param sounding_field_names: Same.
    :param normalization_type_string: Same.
    :param normalization_param_file_name: Same.
    :param min_normalized_value: Same.
    :param max_normalized_value: Same.
    :param target_name: Same.
    :param shuffle_target: Same.
    :param downsampling_classes: Same.
    :param downsampling_fractions: Same.
    :param monitor_string: Same.
    :param weight_loss_function: Same.
    :param x_translations_pixels: Same.
    :param y_translations_pixels: Same.
    :param ccw_rotation_angles_deg: Same.
    :param noise_standard_deviation: Same.
    :param num_noisings: Same.
    :param flip_in_x: Same.
    :param flip_in_y: Same.
    :param top_training_dir_name: Same.
    :param first_training_time_string: Same.
    :param last_training_time_string: Same.
    :param num_examples_per_train_batch: Same.
    :param top_validation_dir_name: Same.
    :param first_validation_time_string: Same.
    :param last_validation_time_string: Same.
    :param num_examples_per_validn_batch: Same.
    :param num_epochs: Same.
    :param num_training_batches_per_epoch: Same.
    :param num_validation_batches_per_epoch: Same.
    :param output_dir_name: Same.
    """

    # Process input args.
    first_training_time_unix_sec = time_conversion.string_to_unix_sec(
        first_training_time_string, TIME_FORMAT)
    last_training_time_unix_sec = time_conversion.string_to_unix_sec(
        last_training_time_string, TIME_FORMAT)

    if top_validation_dir_name in ['', 'None']:
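        # No validation: zero out the validation-batch count and use dummy
        # times.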
        top_validation_dir_name = None
        num_validation_batches_per_epoch = 0
        first_validation_time_unix_sec = 0
        last_validation_time_unix_sec = 0
    else:
        first_validation_time_unix_sec = time_conversion.string_to_unix_sec(
            first_validation_time_string, TIME_FORMAT)
        last_validation_time_unix_sec = time_conversion.string_to_unix_sec(
            last_validation_time_string, TIME_FORMAT)

    if sounding_field_names[0] in ['', 'None']:
        sounding_field_names = None

    if len(downsampling_classes) > 1:
        class_to_sampling_fraction_dict = dict(
            list(zip(downsampling_classes, downsampling_fractions)))
    else:
        class_to_sampling_fraction_dict = None

    if (len(x_translations_pixels) == 1
            and x_translations_pixels + y_translations_pixels == 0):
        x_translations_pixels = None
        y_translations_pixels = None

    if len(ccw_rotation_angles_deg) == 1 and ccw_rotation_angles_deg[0] == 0:
        ccw_rotation_angles_deg = None

    if num_noisings <= 0:
        num_noisings = 0
        noise_standard_deviation = None

    num_channels = len(radar_field_name_by_channel)
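    # The per-channel arrays (one value per radar channel) must all have the
    # same length, which is checked below.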
    expected_dimensions = numpy.array([num_channels], dtype=int)

    error_checking.assert_is_numpy_array(numpy.array(layer_op_name_by_channel),
                                         exact_dimensions=expected_dimensions)

    error_checking.assert_is_numpy_array(min_height_by_channel_m_agl,
                                         exact_dimensions=expected_dimensions)
    error_checking.assert_is_numpy_array(max_height_by_channel_m_agl,
                                         exact_dimensions=expected_dimensions)

    list_of_layer_operation_dicts = [{}] * num_channels
    for m in range(num_channels):
        list_of_layer_operation_dicts[m] = {
            input_examples.RADAR_FIELD_KEY: radar_field_name_by_channel[m],
            input_examples.OPERATION_NAME_KEY: layer_op_name_by_channel[m],
            input_examples.MIN_HEIGHT_KEY: min_height_by_channel_m_agl[m],
            input_examples.MAX_HEIGHT_KEY: max_height_by_channel_m_agl[m]
        }

    # Set output locations.
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=output_dir_name)

    output_model_file_name = '{0:s}/model.h5'.format(output_dir_name)
    history_file_name = '{0:s}/model_history.csv'.format(output_dir_name)
    tensorboard_dir_name = '{0:s}/tensorboard'.format(output_dir_name)
    model_metafile_name = '{0:s}/model_metadata.p'.format(output_dir_name)

    # Find training and validation files.
    training_file_names = input_examples.find_many_example_files(
        top_directory_name=top_training_dir_name,
        shuffled=True,
        first_batch_number=FIRST_BATCH_NUMBER,
        last_batch_number=LAST_BATCH_NUMBER,
        raise_error_if_any_missing=False)

    if top_validation_dir_name is None:
        validation_file_names = []
    else:
        validation_file_names = input_examples.find_many_example_files(
            top_directory_name=top_validation_dir_name,
            shuffled=True,
            first_batch_number=FIRST_BATCH_NUMBER,
            last_batch_number=LAST_BATCH_NUMBER,
            raise_error_if_any_missing=False)

    # Read architecture.
    print(
        'Reading architecture from: "{0:s}"...'.format(input_model_file_name))
    model_object = cnn.read_model(input_model_file_name)
    # model_object = clone_model(model_object)

    # TODO(thunderhoser): This is a HACK.
    model_object.compile(loss=keras.losses.binary_crossentropy,
                         optimizer=keras.optimizers.Adam(),
                         metrics=cnn_setup.DEFAULT_METRIC_FUNCTION_LIST)

    print(SEPARATOR_STRING)
    model_object.summary()
    print(SEPARATOR_STRING)

    print(K.eval(model_object.get_layer(name='radar_conv2d_2').weights[0]))
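    # Printing one convolutional layer's weights here is presumably a sanity
    # check that the architecture file was read correctly.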
    print(SEPARATOR_STRING)

    # Write metadata.
    metadata_dict = {
        cnn.NUM_EPOCHS_KEY: num_epochs,
        cnn.NUM_TRAINING_BATCHES_KEY: num_training_batches_per_epoch,
        cnn.NUM_VALIDATION_BATCHES_KEY: num_validation_batches_per_epoch,
        cnn.MONITOR_STRING_KEY: monitor_string,
        cnn.WEIGHT_LOSS_FUNCTION_KEY: weight_loss_function,
        cnn.CONV_2D3D_KEY: False,
        cnn.VALIDATION_FILES_KEY: validation_file_names,
        cnn.FIRST_VALIDN_TIME_KEY: first_validation_time_unix_sec,
        cnn.LAST_VALIDN_TIME_KEY: last_validation_time_unix_sec,
        cnn.NUM_EX_PER_VALIDN_BATCH_KEY: num_examples_per_validn_batch
    }

    input_tensor = model_object.input
    if isinstance(input_tensor, list):
        input_tensor = input_tensor[0]

    num_grid_rows = input_tensor.get_shape().as_list()[1]
    num_grid_columns = input_tensor.get_shape().as_list()[2]

    training_option_dict = {
        trainval_io.EXAMPLE_FILES_KEY: training_file_names,
        trainval_io.TARGET_NAME_KEY: target_name,
        trainval_io.SHUFFLE_TARGET_KEY: shuffle_target,
        trainval_io.FIRST_STORM_TIME_KEY: first_training_time_unix_sec,
        trainval_io.LAST_STORM_TIME_KEY: last_training_time_unix_sec,
        trainval_io.NUM_EXAMPLES_PER_BATCH_KEY: num_examples_per_train_batch,
        trainval_io.SOUNDING_FIELDS_KEY: sounding_field_names,
        trainval_io.SOUNDING_HEIGHTS_KEY: SOUNDING_HEIGHTS_M_AGL,
        trainval_io.NUM_ROWS_KEY: num_grid_rows,
        trainval_io.NUM_COLUMNS_KEY: num_grid_columns,
        trainval_io.NORMALIZATION_TYPE_KEY: normalization_type_string,
        trainval_io.NORMALIZATION_FILE_KEY: normalization_param_file_name,
        trainval_io.MIN_NORMALIZED_VALUE_KEY: min_normalized_value,
        trainval_io.MAX_NORMALIZED_VALUE_KEY: max_normalized_value,
        trainval_io.BINARIZE_TARGET_KEY: False,
        trainval_io.SAMPLING_FRACTIONS_KEY: class_to_sampling_fraction_dict,
        trainval_io.LOOP_ONCE_KEY: False,
        trainval_io.X_TRANSLATIONS_KEY: x_translations_pixels,
        trainval_io.Y_TRANSLATIONS_KEY: y_translations_pixels,
        trainval_io.ROTATION_ANGLES_KEY: ccw_rotation_angles_deg,
        trainval_io.NOISE_STDEV_KEY: noise_standard_deviation,
        trainval_io.NUM_NOISINGS_KEY: num_noisings,
        trainval_io.FLIP_X_KEY: flip_in_x,
        trainval_io.FLIP_Y_KEY: flip_in_y
    }

    print('Writing metadata to: "{0:s}"...'.format(model_metafile_name))
    cnn.write_model_metadata(
        pickle_file_name=model_metafile_name,
        metadata_dict=metadata_dict,
        training_option_dict=training_option_dict,
        list_of_layer_operation_dicts=list_of_layer_operation_dicts)

    cnn.train_cnn_gridrad_2d_reduced(
        model_object=model_object,
        model_file_name=output_model_file_name,
        history_file_name=history_file_name,
        tensorboard_dir_name=tensorboard_dir_name,
        num_epochs=num_epochs,
        num_training_batches_per_epoch=num_training_batches_per_epoch,
        training_option_dict=training_option_dict,
        list_of_layer_operation_dicts=list_of_layer_operation_dicts,
        monitor_string=monitor_string,
        weight_loss_function=weight_loss_function,
        num_validation_batches_per_epoch=num_validation_batches_per_epoch,
        validation_file_names=validation_file_names,
        first_validn_time_unix_sec=first_validation_time_unix_sec,
        last_validn_time_unix_sec=last_validation_time_unix_sec,
        num_examples_per_validn_batch=num_examples_per_validn_batch)