示例#1
0
def _check_training_args(
        model_file_name, history_file_name, tensorboard_dir_name, num_epochs,
        num_training_batches_per_epoch, num_validation_batches_per_epoch,
        training_option_dict, weight_loss_function):
    """Error-checks input arguments for training.

    All scalar arguments are validated before any output directory is created,
    so an invalid call fails without leaving stray directories on disk.

    :param model_file_name: Path to output file (HDF5 format).  The model will
        be saved here after each epoch.
    :param history_file_name: Path to output file (CSV format).  Training
        history (performance metrics) will be saved here after each epoch.
    :param tensorboard_dir_name: Path to output directory for TensorBoard log
        files.
    :param num_epochs: Number of epochs (integer >= 1).
    :param num_training_batches_per_epoch: Number of training batches in each
        epoch (integer >= 1).
    :param num_validation_batches_per_epoch: Number of validation batches in
        each epoch (integer >= 0).
    :param training_option_dict: See doc for
        `training_validation_io.example_generator_2d_or_3d`.  Missing keys are
        filled from `trainval_io.DEFAULT_OPTION_DICT`; the caller's dict is
        not modified.
    :param weight_loss_function: Boolean flag.  If False, classes will be
        weighted equally in the loss function.  If True, classes will be
        weighted differently (inversely proportional to their sampling
        fractions).
    :return: class_to_weight_dict: Dictionary, where each key is the integer ID
        for a target class (-2 for "dead storm") and each value is the weight
        for the loss function.  If None, classes will be equally weighted in the
        loss function (returned when `weight_loss_function` is False or when no
        sampling fractions are given).
    """

    # Overlay the caller's options on the defaults, working on copies so the
    # caller's dict is never mutated.
    orig_option_dict = training_option_dict.copy()
    training_option_dict = trainval_io.DEFAULT_OPTION_DICT.copy()
    training_option_dict.update(orig_option_dict)

    # Validate all cheap scalar arguments first; nothing has side effects yet.
    error_checking.assert_is_integer(num_epochs)
    error_checking.assert_is_geq(num_epochs, 1)
    error_checking.assert_is_integer(num_training_batches_per_epoch)
    error_checking.assert_is_geq(num_training_batches_per_epoch, 1)
    error_checking.assert_is_integer(num_validation_batches_per_epoch)
    error_checking.assert_is_geq(num_validation_batches_per_epoch, 0)
    error_checking.assert_is_boolean(weight_loss_function)

    # Only after every argument has passed, create the output locations.
    file_system_utils.mkdir_recursive_if_necessary(file_name=model_file_name)
    file_system_utils.mkdir_recursive_if_necessary(file_name=history_file_name)
    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=tensorboard_dir_name)

    if not weight_loss_function:
        return None

    class_to_sampling_fraction_dict = training_option_dict[
        trainval_io.SAMPLING_FRACTIONS_KEY
    ]
    # Without explicit sampling fractions there is nothing to invert, so fall
    # back to equal class weights.
    if class_to_sampling_fraction_dict is None:
        return None

    return dl_utils.class_fractions_to_weights(
        sampling_fraction_by_class_dict=class_to_sampling_fraction_dict,
        target_name=training_option_dict[trainval_io.TARGET_NAME_KEY],
        binarize_target=training_option_dict[trainval_io.BINARIZE_TARGET_KEY]
    )
    def test_class_fractions_to_weights_wind_7class_binarized(self):
        """Ensures correct output from class_fractions_to_weights.

        In this case, target variable = wind-speed class; there are 7 classes;
        and binarize_target = True.

        Loss weights are compared key by key within TOLERANCE_FOR_CLASS_WEIGHT
        rather than with exact dict equality, so harmless floating-point
        rounding does not fail the test.  On failure the message names the
        offending class.
        """

        this_dict = dl_utils.class_fractions_to_weights(
            sampling_fraction_by_class_dict=
            SAMPLING_FRACTION_BY_WIND_7CLASS_DICT,
            target_name=WIND_TARGET_NAME_7CLASSES,
            binarize_target=True)

        expected_dict = LF_WEIGHT_BY_WIND_7CLASS_DICT_BINARIZED

        # Both dicts must describe exactly the same set of classes.
        self.assertEqual(set(expected_dict.keys()), set(this_dict.keys()))

        for this_key, this_expected_weight in expected_dict.items():
            self.assertTrue(
                numpy.isclose(this_expected_weight, this_dict[this_key],
                              atol=TOLERANCE_FOR_CLASS_WEIGHT),
                msg='Weight for class {0} is {1}; expected {2}.'.format(
                    this_key, this_dict[this_key], this_expected_weight))
    def test_class_fractions_to_weights_tornado_nonbinarized(self):
        """Ensures correct output from class_fractions_to_weights.

        In this case, target variable = tornado occurrence and binarize_target =
        False.

        Loss weights are compared key by key within TOLERANCE_FOR_CLASS_WEIGHT
        rather than with exact dict equality, so harmless floating-point
        rounding does not fail the test.  On failure the message names the
        offending class.
        """

        this_dict = dl_utils.class_fractions_to_weights(
            sampling_fraction_by_class_dict=
            SAMPLING_FRACTION_BY_TORNADO_CLASS_DICT,
            target_name=TORNADO_TARGET_NAME,
            binarize_target=False)

        expected_dict = LF_WEIGHT_BY_TORNADO_CLASS_DICT

        # Both dicts must describe exactly the same set of classes.
        self.assertEqual(set(expected_dict.keys()), set(this_dict.keys()))

        for this_key, this_expected_weight in expected_dict.items():
            self.assertTrue(
                numpy.isclose(this_expected_weight, this_dict[this_key],
                              atol=TOLERANCE_FOR_CLASS_WEIGHT),
                msg='Weight for class {0} is {1}; expected {2}.'.format(
                    this_key, this_dict[this_key], this_expected_weight))
    def test_class_fractions_to_weights_tornado_binarized(self):
        """Ensures correct output from class_fractions_to_weights.

        Here the target variable is tornado occurrence and binarize_target is
        True.  The result must have the same classes as
        LF_WEIGHT_BY_TORNADO_CLASS_DICT, and every weight must agree with it
        to within TOLERANCE_FOR_CLASS_WEIGHT.
        """

        actual_weight_dict = dl_utils.class_fractions_to_weights(
            sampling_fraction_by_class_dict=(
                SAMPLING_FRACTION_BY_TORNADO_CLASS_DICT),
            target_name=TORNADO_TARGET_NAME,
            binarize_target=True)
        expected_weight_dict = LF_WEIGHT_BY_TORNADO_CLASS_DICT

        # The set of class IDs must match exactly.
        self.assertTrue(set(expected_weight_dict) == set(actual_weight_dict))

        # Weights are compared with an absolute tolerance, not exactly.
        for class_id, expected_weight in expected_weight_dict.items():
            weights_match = numpy.isclose(
                expected_weight, actual_weight_dict[class_id],
                atol=TOLERANCE_FOR_CLASS_WEIGHT)
            self.assertTrue(weights_match)