Esempio n. 1
0
    def test_import_data(self):
        """Exercise ``utilities.import_data`` with each input kind used here.

        A file path, the output of ``utilities.csv_to_dict`` and an
        already-imported array are each fed through ``import_data``; a plain
        list is expected to be rejected with ``NotImplementedError``.
        """
        # Input given as a path to the CSV fixture
        file_values, file_names = utilities.import_data(self.data_path)
        self.assertEqual(file_names, self.expected_var_names)
        assert_array_equal(file_values, self.expected_arr)

        # Input given as the parsed CSV mapping
        parsed_csv = utilities.csv_to_dict(self.data_path, dtype=float)
        dict_values, dict_names = utilities.import_data(parsed_csv)
        self.assertEqual(dict_names, self.expected_var_names)
        assert_array_equal(dict_values, self.expected_arr)

        # Input given as an array: names fall back to the column indices
        array_values, array_names = utilities.import_data(dict_values)
        n_columns = len(dict_values[0, :])
        self.assertEqual(array_names, list(range(n_columns)))
        assert_array_equal(array_values, self.expected_arr)

        # A plain list is not an accepted input in this test
        self.assertRaises(NotImplementedError, utilities.import_data, [0, 1])
Esempio n. 2
0
 def test_no_header_row(self):
     """Check ``csv_to_dict`` output when ``header_row=False`` is passed."""
     read_options = {"dtype": float, "header_row": False}
     parsed = utilities.csv_to_dict(self.data_path_nh, **read_options)
     self.assertEqual(parsed, self.expected_dict_nh)
Esempio n. 3
0
 def test_csv_to_dict(self):
     """Check that ``csv_to_dict`` reproduces the expected dict fixture."""
     self.assertEqual(
         utilities.csv_to_dict(self.data_path, dtype=float),
         self.expected_dict,
     )
Esempio n. 4
0
 def test_dict_to_array(self):
     """Check the ``data`` and ``var_names`` entries from ``dict_to_array``."""
     converted = utilities.dict_to_array(
         utilities.csv_to_dict(self.data_path, dtype=float)
     )
     assert_array_equal(converted["data"], self.expected_arr)
     self.assertEqual(converted["var_names"], self.expected_var_names)
Esempio n. 5
0
    def test_widen_data(self):
        """Test ``utilities.widen_data`` on the narrow example data set.

        Covers, in order: the default call, a call with ``date_col=None``,
        rejection of columns with mismatched lengths, rejection of an unknown
        ``multi_val_policy``, the "last" / "min" / "max" policies, and
        ``remove_partial_columns=True``.
        """
        data_dict = utilities.csv_to_dict(example_narrow_data)
        kwargs = {
            "uid_columns": ["patient", "plan", "field id"],
            "x_data_cols": ["DD(%)", "DTA(mm)", "Threshold(%)"],
            "y_data_col": "Gamma Pass Rate(%)",
            "date_col": "date",
            "dtype": float,
        }
        ds = utilities.widen_data(data_dict, **kwargs)

        expected = {
            "uid":
            ["ANON1234 && Plan_name && 3", "ANON1234 && Plan_name && 4"],
            "date": ["6/13/2019 7:27", "6/13/2019 7:27"],
            "2.0 && 3.0 && 10.0": [np.nan, 99.99476863],
            "2.0 && 3.0 && 5.0": [99.88772435, 99.99533258],
            "3.0 && 2.0 && 10.0": [99.94708217, 99.99941874],
            "3.0 && 3.0 && 10.0": [99.97706894, 100],
            "3.0 && 3.0 && 5.0": [99.97934552, 100],
        }

        # assert_array_equal treats NaN entries in matching positions as
        # equal, which is what lets the np.nan placeholder above compare.
        for key, exp_value in expected.items():
            assert_array_equal(ds[key], exp_value)

        # No date test: same arguments, only date_col overridden.
        kwargs_no_date = dict(kwargs, date_col=None)
        ds_2 = utilities.widen_data(data_dict, **kwargs_no_date)
        # Only the keys actually returned are compared against `expected`.
        for key, ds_2_value in ds_2.items():
            assert_array_equal(ds_2_value, expected[key])

        # test column length check: lengthen one column so they are unequal
        data_dict_2 = utilities.csv_to_dict(example_narrow_data)
        first_column = next(iter(data_dict_2))
        data_dict_2[first_column].append("test")
        with self.assertRaises(NotImplementedError):
            utilities.widen_data(data_dict_2, **kwargs)

        # test policy check: an unknown policy name must be rejected
        with self.assertRaises(NotImplementedError):
            utilities.widen_data(data_dict, multi_val_policy="test", **kwargs)

        target_uid = "ANON1234 && Plan_name && 4"
        target_col = "2.0 && 3.0 && 10.0"

        def widen_with_policy(policy):
            """Return (row index, cell value) for the target uid/column."""
            widened = utilities.widen_data(
                data_dict, multi_val_policy=policy, **kwargs
            )
            row = widened["uid"].index(target_uid)
            return row, widened[target_col][row]

        # NOTE(review): presumably the fixture holds more than one value for
        # this uid/column pair, one of them 50 -- confirm against
        # example_narrow_data.
        _, last_value = widen_with_policy("last")
        self.assertEqual(last_value, 50)

        _, min_value = widen_with_policy("min")
        self.assertEqual(min_value, 50)

        # "max" is expected to agree with the default-call result above.
        max_row, max_value = widen_with_policy("max")
        self.assertEqual(max_value, expected[target_col][max_row])

        ds = utilities.widen_data(data_dict,
                                  remove_partial_columns=True,
                                  **kwargs)
        # The only column with a missing (NaN) entry must be dropped.
        self.assertNotIn(target_col, ds)