def _check_training_args(
        model_file_name, history_file_name, tensorboard_dir_name, num_epochs,
        num_training_batches_per_epoch, num_validation_batches_per_epoch,
        training_option_dict, weight_loss_function):
    """Validates training arguments and derives loss-function class weights.

    Output locations are passed to
    `file_system_utils.mkdir_recursive_if_necessary` *before* the numeric
    arguments are validated, so that call happens even if a later check fails.

    :param model_file_name: Path to output file for the model (HDF5 format).
    :param history_file_name: Path to output file for training history (CSV
        format).
    :param tensorboard_dir_name: Path to output directory for TensorBoard log
        files.
    :param num_epochs: Number of epochs (integer >= 1).
    :param num_training_batches_per_epoch: Number of training batches per
        epoch (integer >= 1).
    :param num_validation_batches_per_epoch: Number of validation batches per
        epoch (integer >= 0).
    :param training_option_dict: Dictionary of training options.  Keys missing
        here are filled from `trainval_io.DEFAULT_OPTION_DICT`; the caller's
        dictionary is not modified.  See doc for
        `training_validation_io.example_generator_2d_or_3d`.
    :param weight_loss_function: Boolean flag.  If False, no class weights are
        computed.  If True, weights are computed from the sampling fractions
        (when those are specified).
    :return: class_to_weight_dict: Dictionary, where each key is the integer ID
        for a target class (-2 for "dead storm") and each value is the weight
        for the loss function.  This is None if `weight_loss_function` is False
        or if the options contain no sampling fractions.
    """

    # Overlay the caller's options on the defaults, using copies on both
    # sides so that neither original dictionary is mutated.
    user_option_dict = training_option_dict.copy()
    full_option_dict = trainval_io.DEFAULT_OPTION_DICT.copy()
    full_option_dict.update(user_option_dict)

    for this_file_name in (model_file_name, history_file_name):
        file_system_utils.mkdir_recursive_if_necessary(
            file_name=this_file_name)

    file_system_utils.mkdir_recursive_if_necessary(
        directory_name=tensorboard_dir_name)

    # Each count must be an integer no smaller than its paired minimum.
    count_and_minimum_pairs = (
        (num_epochs, 1),
        (num_training_batches_per_epoch, 1),
        (num_validation_batches_per_epoch, 0),
    )

    for this_count, this_minimum in count_and_minimum_pairs:
        error_checking.assert_is_integer(this_count)
        error_checking.assert_is_geq(this_count, this_minimum)

    error_checking.assert_is_boolean(weight_loss_function)

    # The sampling fractions are looked up only when weighting is requested.
    if weight_loss_function:
        sampling_fraction_dict = full_option_dict[
            trainval_io.SAMPLING_FRACTIONS_KEY
        ]
    else:
        sampling_fraction_dict = None

    if sampling_fraction_dict is None:
        return None

    return dl_utils.class_fractions_to_weights(
        sampling_fraction_by_class_dict=sampling_fraction_dict,
        target_name=full_option_dict[trainval_io.TARGET_NAME_KEY],
        binarize_target=full_option_dict[trainval_io.BINARIZE_TARGET_KEY]
    )
def test_class_fractions_to_weights_wind_7class_binarized(self):
    """Ensures correct output from class_fractions_to_weights.

    In this case, target variable = wind-speed class; there are 7 classes;
    and binarize_target = True.
    """

    this_dict = dl_utils.class_fractions_to_weights(
        sampling_fraction_by_class_dict=SAMPLING_FRACTION_BY_WIND_7CLASS_DICT,
        target_name=WIND_TARGET_NAME_7CLASSES, binarize_target=True)

    expected_dict = LF_WEIGHT_BY_WIND_7CLASS_DICT_BINARIZED

    # Class IDs must match exactly.  Weights are compared within a tolerance
    # rather than with `==`, since exact equality is brittle for
    # floating-point values.
    self.assertEqual(set(expected_dict.keys()), set(this_dict.keys()))

    for this_key in expected_dict:
        self.assertTrue(
            numpy.isclose(
                expected_dict[this_key], this_dict[this_key],
                atol=TOLERANCE_FOR_CLASS_WEIGHT),
            msg='Weight mismatch for class {0}: expected {1}, got {2}'.format(
                this_key, expected_dict[this_key], this_dict[this_key])
        )
def test_class_fractions_to_weights_tornado_nonbinarized(self):
    """Ensures correct output from class_fractions_to_weights.

    In this case, target variable = tornado occurrence and
    binarize_target = False.
    """

    this_dict = dl_utils.class_fractions_to_weights(
        sampling_fraction_by_class_dict=
        SAMPLING_FRACTION_BY_TORNADO_CLASS_DICT,
        target_name=TORNADO_TARGET_NAME, binarize_target=False)

    expected_dict = LF_WEIGHT_BY_TORNADO_CLASS_DICT

    # Class IDs must match exactly.  Weights are compared within a tolerance
    # rather than with `==`, since exact equality is brittle for
    # floating-point values.
    self.assertEqual(set(expected_dict.keys()), set(this_dict.keys()))

    for this_key in expected_dict:
        self.assertTrue(
            numpy.isclose(
                expected_dict[this_key], this_dict[this_key],
                atol=TOLERANCE_FOR_CLASS_WEIGHT),
            msg='Weight mismatch for class {0}: expected {1}, got {2}'.format(
                this_key, expected_dict[this_key], this_dict[this_key])
        )
def test_class_fractions_to_weights_tornado_binarized(self):
    """Checks class_fractions_to_weights for the binarized tornado target.

    In this case, target variable = tornado occurrence and
    binarize_target = True.  The result is compared against
    LF_WEIGHT_BY_TORNADO_CLASS_DICT: same class IDs, and weights equal within
    TOLERANCE_FOR_CLASS_WEIGHT.
    """

    actual_weight_dict = dl_utils.class_fractions_to_weights(
        sampling_fraction_by_class_dict=
        SAMPLING_FRACTION_BY_TORNADO_CLASS_DICT,
        target_name=TORNADO_TARGET_NAME, binarize_target=True)

    expected_class_ids = list(LF_WEIGHT_BY_TORNADO_CLASS_DICT.keys())
    actual_class_ids = list(actual_weight_dict.keys())

    # Order of the keys is irrelevant; only the set of class IDs matters.
    same_class_ids = set(expected_class_ids) == set(actual_class_ids)
    self.assertTrue(same_class_ids)

    for class_id in expected_class_ids:
        expected_weight = LF_WEIGHT_BY_TORNADO_CLASS_DICT[class_id]
        actual_weight = actual_weight_dict[class_id]

        weights_match = numpy.isclose(
            expected_weight, actual_weight, atol=TOLERANCE_FOR_CLASS_WEIGHT)
        self.assertTrue(weights_match)