def test_check_stopping_criterion_wind_need_examples(self):
    """Ensures correct output from _check_stopping_criterion.

    Scenario: the target variable is wind, downsampling is not used
    (no sampling-fraction dict), and more examples are still needed.
    The stopping criterion must therefore be False.
    """

    should_stop = trainval_io._check_stopping_criterion(
        num_examples_per_batch=NUM_EXAMPLES_PER_BATCH,
        class_to_batch_size_dict=CLASS_TO_BATCH_SIZE_DICT_WIND,
        class_to_sampling_fraction_dict=None,
        target_values_in_memory=TARGET_VALUES_50ZEROS
    )

    self.assertFalse(should_stop)
def test_check_stopping_criterion_wind_have_classes(self):
    """Ensures correct output from _check_stopping_criterion.

    Scenario: the target variable is wind and downsampling is on.
    Downsampling is irrelevant here, because the examples in memory
    already cover every class. The stopping criterion must therefore
    be True.
    """

    should_stop = trainval_io._check_stopping_criterion(
        num_examples_per_batch=NUM_EXAMPLES_PER_BATCH,
        class_to_batch_size_dict=CLASS_TO_BATCH_SIZE_DICT_WIND,
        class_to_sampling_fraction_dict=DOWNSAMPLING_DICT_WIND,
        target_values_in_memory=TARGET_VALUES_WIND
    )

    self.assertTrue(should_stop)
def test_check_stopping_criterion_wind_no_downsampling(self):
    """Ensures correct output from _check_stopping_criterion.

    Scenario: the target variable is wind and every example in memory
    has target = 0. Because downsampling is off (no sampling-fraction
    dict), class balance does not matter. The stopping criterion must
    therefore be True.
    """

    should_stop = trainval_io._check_stopping_criterion(
        num_examples_per_batch=NUM_EXAMPLES_PER_BATCH,
        class_to_batch_size_dict=CLASS_TO_BATCH_SIZE_DICT_WIND,
        class_to_sampling_fraction_dict=None,
        target_values_in_memory=TARGET_VALUES_200ZEROS
    )

    self.assertTrue(should_stop)
def test_check_stopping_criterion_wind_need_classes(self):
    """Ensures correct output from _check_stopping_criterion.

    Scenario: the target variable is wind, every example in memory has
    target = 0, and downsampling is on. Since other classes are still
    missing, the stopping criterion must be False.
    """

    should_stop = trainval_io._check_stopping_criterion(
        num_examples_per_batch=NUM_EXAMPLES_PER_BATCH,
        class_to_batch_size_dict=CLASS_TO_BATCH_SIZE_DICT_WIND,
        class_to_sampling_fraction_dict=DOWNSAMPLING_DICT_WIND,
        target_values_in_memory=TARGET_VALUES_200ZEROS
    )

    self.assertFalse(should_stop)