예제 #1
0
    def test_calc_flat_whist_noclip(self):
        """Check `calc_flat_whist` output when no clipping is requested.

        Derivation of the expected weight histogram:

            bin edges      : 0, 50, 100, 150, 200
            raw counts     : 3,   0,   2,   0
            counts + 1     : 4,   1,   3,   1      (regularization)
            inverse counts : 1/4, 1,   1/3, 1
            normalisation  : 1/4 + 1 + 1/3 + 1 = 31/12
            expected whist : 3/31, 12/31, 4/31, 12/31
        """
        sample = [1, 11, 111, 121, 4]
        n_bins = 4
        lo_hi = (0, 200)

        expected_edges = [0, 50, 100, 150, 200]
        expected_whist = [3 / 31, 12 / 31, 4 / 31, 12 / 31]

        result = calc_flat_whist(
            DictLoader({'weight': sample}),
            var='weight',
            bins=n_bins,
            range=lo_hi,
            clip=None,
        )
        got_values, got_whist, got_edges = result

        # Values must be passed through unchanged; histogram and bin edges
        # must match the hand-computed reference above.
        self.assertTrue(nan_equal(got_values, sample))
        self.assertTrue(nan_equal(got_whist, expected_whist))
        self.assertTrue(nan_equal(got_edges, expected_edges))
예제 #2
0
    def test_varr_var(self):
        """Slice a variable-length array variable and compare the result.

        Rows 0 and 4 are selected from five rows of differing length.
        """
        full = {'var': [[1, 2], [], [3], [4, 5, 6, 7], [-1]]}
        expected = {'var': [[1, 2], [-1]]}
        picked_rows = [0, 4]

        sliced_loader = DataSlice(DictLoader(full), picked_rows)
        self._compare_varr_vars(expected, sliced_loader, 'var')
예제 #3
0
    def test_scalar_var(self):
        """Slice a scalar variable and compare the result.

        Rows 0, 2 and 3 are selected from a five-element column.
        """
        full = {'var': [1, 2, 3, 4, -1]}
        expected = {'var': [1, 3, 4]}
        picked_rows = [0, 2, 3]

        sliced_loader = DataSlice(DictLoader(full), picked_rows)
        self._compare_scalar_vars(expected, sliced_loader, 'var')
    def _make_dgen(pdg_list, iscc_list, target_pdg_iscc_list):
        """Build a `DataGenerator` over in-memory 'pdg' and 'iscc' columns.

        Args:
            pdg_list: values stored under the 'pdg' key of the loader.
            iscc_list: values stored under the 'iscc' key of the loader.
            target_pdg_iscc_list: forwarded unchanged to `DataGenerator`.

        Returns:
            The constructed `DataGenerator`.
        """
        columns = {'pdg': pdg_list, 'iscc': iscc_list}
        target_options = {
            'target_pdg_iscc_list': target_pdg_iscc_list,
            'var_target_pdg': 'pdg',
            'var_target_iscc': 'iscc',
        }

        # NOTE(review): the positional 10 is presumably the batch size --
        # confirm against the DataGenerator signature.
        return DataGenerator(DictLoader(columns), 10, **target_options)
예제 #5
0
    def test_filter_wildcard_iscc(self):
        """Filter with a `None` ISCC entry acting as a wildcard.

        The pattern (1, None) is expected to keep both rows with pdg == 1
        (idx 0 and 3), whatever their iscc value.
        """
        columns = dict(
            pdg=[1, 2, 0, 1, 2],
            iscc=[0, 1, 0, 1, 0],
            idx=[0, 1, 2, 3, 4],
        )
        keep_patterns = [(1, None)]
        expected = {'idx': [0, 3]}

        filtered_loader = DataFilter(
            DictLoader(columns), 'pdg', 'iscc', keep_patterns
        )
        self._compare_scalar_vars(expected, filtered_loader, 'idx')
예제 #6
0
    def test_filter_simple(self):
        """Filter on two explicit (pdg, iscc) pairs.

        Only idx 0 (pdg=1, iscc=0) and idx 2 (pdg=0, iscc=0) match the
        kept pairs.
        """
        columns = dict(
            pdg=[1, 2, 0, 1, 2],
            iscc=[0, 1, 0, 1, 0],
            idx=[0, 1, 2, 3, 4],
        )
        keep_patterns = [(0, 0), (1, 0)]
        expected = {'idx': [0, 2]}

        filtered_loader = DataFilter(
            DictLoader(columns), 'pdg', 'iscc', keep_patterns
        )
        self._compare_scalar_vars(expected, filtered_loader, 'idx')
예제 #7
0
    def test_filter_pass_none(self):
        """Filter with a pair that matches no row; expect an empty result."""
        columns = dict(
            pdg=[1, 2, 0, 1, 2],
            iscc=[0, 1, 0, 1, 0],
            idx=[0, 1, 2, 3, 4],
        )
        # (-1, -1) does not occur anywhere in the data above.
        keep_patterns = [(-1, -1)]
        expected = {'idx': []}

        filtered_loader = DataFilter(
            DictLoader(columns), 'pdg', 'iscc', keep_patterns
        )
        self._compare_scalar_vars(expected, filtered_loader, 'idx')
예제 #8
0
    def test_filter_pass_all(self):
        """Filter with every (pdg, iscc) pair present; expect all rows kept."""
        columns = dict(
            pdg=[1, 2, 0, 1, 2],
            iscc=[0, 1, 0, 1, 0],
            idx=[0, 1, 2, 3, 4],
        )
        # Covers each (pdg, iscc) combination that appears in the data.
        keep_patterns = [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1)]
        expected = {'idx': [0, 1, 2, 3, 4]}

        filtered_loader = DataFilter(
            DictLoader(columns), 'pdg', 'iscc', keep_patterns
        )
        self._compare_scalar_vars(expected, filtered_loader, 'idx')
예제 #9
0
    def test_filter_missing_value(self):
        """Filter where one of the kept pairs is absent from the data.

        (0, 0) matches idx 2; (-1, 0) matches nothing and must not affect
        the result.
        """
        columns = dict(
            pdg=[1, 2, 0, 1, 2],
            iscc=[0, 1, 0, 1, 0],
            idx=[0, 1, 2, 3, 4],
        )
        keep_patterns = [(0, 0), (-1, 0)]
        expected = {'idx': [2]}

        filtered_loader = DataFilter(
            DictLoader(columns), 'pdg', 'iscc', keep_patterns
        )
        self._compare_scalar_vars(expected, filtered_loader, 'idx')
예제 #10
0
    def test_varr_var(self):
        """Test shuffling of a single variable length array variable.

        The reference permutation is produced by seeding NumPy's global RNG
        with the same seed given to `DataShuffle` and shuffling an index
        array; the loader is then compared against the data reordered by
        that permutation.
        """
        seed = 321
        # The rows have different lengths, so the array must be created with
        # an explicit object dtype: NumPy >= 1.24 raises ValueError when asked
        # to build an array from a ragged nested sequence implicitly.
        data = {
            'var': np.array([[1, 2], [], [3], [4, 5, 6, 7], [-1]],
                            dtype=object)
        }
        data_loader = DataShuffle(DictLoader(data), seed)

        indices_original = np.arange(0, len(data['var']))
        indices_shuffled = np.array(indices_original[:])

        np.random.seed(seed)
        np.random.shuffle(indices_shuffled)
        # Guard against a seed that happens to yield the identity
        # permutation, which would make the comparison below vacuous.
        self.assertTrue(
            np.any(~np.isclose(indices_original, indices_shuffled)))

        data_shuffled = {'var': data['var'][indices_shuffled]}
        self._compare_varr_vars(data_shuffled, data_loader, 'var')
예제 #11
0
    def test_scalar_var(self):
        """Shuffle a scalar variable and compare with a reference permutation.

        The reference permutation comes from seeding NumPy's global RNG with
        the same seed passed to `DataShuffle` and shuffling an index array.
        """
        seed = 1223
        column = np.array([1, 2, 3, 4, -1])
        data = {'var': column}
        shuffled_loader = DataShuffle(DictLoader(data), seed)

        identity = np.arange(0, len(column))
        permutation = np.array(identity[:])

        np.random.seed(seed)
        np.random.shuffle(permutation)

        # The seed must actually reorder something, otherwise the comparison
        # below would not exercise the shuffle.
        differs = np.logical_not(np.isclose(identity, permutation))
        self.assertTrue(np.any(differs))

        expected = {'var': column[permutation]}
        self._compare_scalar_vars(expected, shuffled_loader, 'var')
def make_data_generator(data_loader=DictLoader(TEST_DATA),
                        vars_input_slice=TEST_INPUT_VARS_SLICE,
                        vars_input_png3d=TEST_INPUT_VARS_PNG3D,
                        var_target_pdg=TEST_TARGET_VAR_PDG,
                        var_target_iscc=TEST_TARGET_VAR_ISCC,
                        **kwargs):
    """Build a `DataGenerator` with the module's test defaults.

    Every named parameter is forwarded to `DataGenerator` under the same
    keyword; any extra keyword arguments are passed through unchanged.

    Note that the defaults are evaluated once, when the function is
    defined, so the default loader object is shared between calls.
    """
    # pylint: disable=dangerous-default-value

    generator_args = dict(
        data_loader=data_loader,
        vars_input_slice=vars_input_slice,
        vars_input_png3d=vars_input_png3d,
        var_target_pdg=var_target_pdg,
        var_target_iscc=var_target_iscc,
    )
    generator_args.update(kwargs)

    return DataGenerator(**generator_args)
예제 #13
0
def make_data_generator(data_loader=DictLoader(TEST_DATA),
                        vars_input_slice=TEST_INPUT_VARS_SLICE,
                        vars_input_png3d=TEST_INPUT_VARS_PNG3D,
                        vars_input_png2d=TEST_INPUT_VARS_PNG2D,
                        var_target_total=TEST_TARGET_VAR_TOTAL,
                        var_target_primary=TEST_TARGET_VAR_PRIMARY,
                        **kwargs):
    """Build a `DataGenerator` with the module's test defaults.

    Every named parameter is forwarded to `DataGenerator` under the same
    keyword; any extra keyword arguments are passed through unchanged.

    Note that the defaults are evaluated once, when the function is
    defined, so the default loader object is shared between calls.
    """
    # pylint: disable=dangerous-default-value

    generator_args = dict(
        data_loader=data_loader,
        vars_input_slice=vars_input_slice,
        vars_input_png3d=vars_input_png3d,
        vars_input_png2d=vars_input_png2d,
        var_target_total=var_target_total,
        var_target_primary=var_target_primary,
    )
    generator_args.update(kwargs)

    return DataGenerator(**generator_args)
예제 #14
0
    def test_flat_weights_noclip(self):
        """Check per-sample weights from `flat_weights` without clipping.

        The expected weights are built by looking up each sample's bin in
        the reference weight histogram and rescaling so that the weights
        sum to the number of samples.
        """
        n_bins = 4
        lo_hi = (0, 200)

        # Four equal bins over (0, 200): edges are 0, 50, 100, 150, 200.
        sample = [1, 11, 111, 121, 4]
        # Index of the bin each sample value falls into.
        sample_bins = [0, 0, 2, 2, 0]
        per_bin_weight = [3 / 31, 12 / 31, 4 / 31, 12 / 31]

        raw = [per_bin_weight[idx] for idx in sample_bins]
        raw_total = sum(raw)
        n_samples = len(sample)
        expected = [w * n_samples / raw_total for w in raw]

        actual = flat_weights(
            DictLoader({'weight': sample}),
            var='weight',
            bins=n_bins,
            range=lo_hi,
            clip=None,
        )

        self.assertTrue(nan_equal(actual, expected))
예제 #15
0
    def test_calc_flat_whist(self):
        """Check `calc_flat_whist` output when clipping is applied.

        Derivation of the expected weight histogram:

            bin edges      : 0, 50, 100, 150, 200
            raw counts     : 3,   0,   2,   0
            counts + 1     : 4,   1,   3,   1      (regularization)
            inverse counts : 1/4, 1,   1/3, 1
            clip ceiling   : min(inverse) * clip = 1/4 * 2 = 1/2
            clipped        : 1/4, 1/2, 1/3, 1/2
            normalisation  : 1/4 + 1/2 + 1/3 + 1/2 = 19/12
            expected whist : 3/19, 6/19, 4/19, 6/19
        """
        sample = [1, 11, 111, 121, 4]
        n_bins = 4
        lo_hi = (0, 200)
        clip_factor = 2

        expected_edges = [0, 50, 100, 150, 200]
        expected_whist = [3 / 19, 6 / 19, 4 / 19, 6 / 19]

        result = calc_flat_whist(
            DictLoader({'weight': sample}),
            var='weight',
            bins=n_bins,
            range=lo_hi,
            clip=clip_factor,
        )
        got_values, got_whist, got_edges = result

        # Values must be passed through unchanged; histogram and bin edges
        # must match the hand-computed reference above.
        self.assertTrue(nan_equal(got_values, sample))
        self.assertTrue(nan_equal(got_whist, expected_whist))
        self.assertTrue(nan_equal(got_edges, expected_edges))
예제 #16
0
 def make_balanced_sampler(data, pdg_iscc_list, seed, pdg_signed=False):
     # pylint: disable=unused-argument
     """Wrap dict `data` in a `DictLoader` and build a `BalancedSampler`.

     The sampler is constructed with 'pdg' and 'iscc' as its variable
     names, followed by `pdg_iscc_list` and `seed`. `pdg_signed` is
     accepted but not used by this helper.
     """
     loader = DictLoader(data)
     return BalancedSampler(loader, 'pdg', 'iscc', pdg_iscc_list, seed)
예제 #17
0
 def _create_data_loader(self, data):
     """Return a `DictLoader` wrapping `data`."""
     loader = DictLoader(data)
     return loader