    def test_add_subject_information(self):
        patients_1, _ = self.create_spark_df({"patientID": [1, 2]})
        events_1, _ = self.create_spark_df({
            "patientID": [1, 2],
            "value": ["fracture", "fracture"]
        })
        input1 = Cohort("liberal_fractures", "liberal_fractures", patients_1,
                        events_1)
        patients_2, _ = self.create_spark_df({
            "patientID": [1, 2],
            "gender": [1, 1],
            "birthDate": [
                pd.to_datetime("1993-10-09"),
                pd.to_datetime("1992-03-14"),
            ],
            "deathDate": [
                pd.to_datetime("1993-10-09"),
                pd.to_datetime("1992-03-14"),
            ],
        })

        base_cohort1 = Cohort("patients", "patients", patients_2)
        # "error": raise if some subjects lack information; here all match.
        input1.add_subject_information(base_cohort1, "error")

        self.assertTrue(input1.has_subject_information()
                        and input1.subjects.count() == 2)

        patients_3, _ = self.create_spark_df({
            "patientID": [1],
            "gender": [1],
            "birthDate": [pd.to_datetime("1993-10-09")],
            "deathDate": [pd.to_datetime("1993-10-09")],
        })
        base_cohort2 = Cohort("liberal_fractures", "liberal_fractures",
                              patients_3, None)
        input2 = Cohort("liberal_fractures", "liberal_fractures", patients_1,
                        events_1)
        # "omit": drop subjects without information, but keep their events.
        input2.add_subject_information(base_cohort2, "omit")
        self.assertTrue(input2.has_subject_information()
                        and input2.subjects.count() == 1
                        and input2.events.count() == 2)

        input3 = Cohort("liberal_fractures", "liberal_fractures", patients_1,
                        events_1)
        # "omit_all": drop subjects without information and their events too.
        input3.add_subject_information(base_cohort2, "omit_all")
        self.assertTrue(input3.has_subject_information()
                        and input3.subjects.count() == 1
                        and input3.events.count() == 1)

    def test_has_subject_information(self):
        patients_1, _ = self.create_spark_df({"patientID": [1, 2]})
        cohort1 = Cohort("liberal_fractures", "liberal_fractures", patients_1)
        patients_2, _ = self.create_spark_df({
            "patientID": [1, 2],
            "gender": [1, 1],
            "birthDate": [
                pd.to_datetime("1993-10-09"),
                pd.to_datetime("1992-03-14"),
            ],
            "deathDate": [
                pd.to_datetime("1993-10-09"),
                pd.to_datetime("1992-03-14"),
            ],
        })

        cohort2 = Cohort("liberal_fractures", "liberal_fractures", patients_2,
                         None)

        self.assertFalse(cohort1.has_subject_information())
        self.assertTrue(cohort2.has_subject_information())
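
    # A minimal sketch of the create_spark_df fixture these tests rely on
    # (illustrative only, not from the source: the real helper lives in the
    # project's test base class, and the `spark` session attribute is an
    # assumption). It builds a Spark DataFrame from a dict of columns and
    # also returns the pandas frame it was built from.
    def create_spark_df(self, data):
        pandas_df = pd.DataFrame(data)
        return self.spark.createDataFrame(pandas_df), pandas_df
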
Example #3
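
    # Context assumed for this example (not shown in the snippet): these are
    # methods of a feature-driver class; `sf` is pyspark.sql.functions,
    # `np` is numpy, `copy` comes from the copy module, and Cohort,
    # DataFrame, rename_df_columns, Tuple and List are imported at module
    # level in the real source.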
    def _find_subjects_with_age_inconsistent_w_age_groups(
            self, cohort: Cohort) -> Cohort:
        """Check that min and max age groups are consistent with subjects' ages."""
        if not cohort.has_subject_information():
            raise ValueError("Cohort should have subject information.")
        duplicate = copy(cohort)
        duplicate.add_age_information(
            self.age_reference_date)  # add starting age
        # Subjects age over the study, so with longitudinal age groups the
        # admissible starting-age window is narrowed by the study length
        # (in years, rounded up).
        study_length = (np.ceil(
            (self.study_end - self.study_start).days /
            365.25) if self.is_using_longitudinal_age_groups else 0)
        min_starting_age = min(self.age_groups)
        max_starting_age = max(self.age_groups) - study_length
        invalid_subjects = duplicate.subjects.where(
            ~sf.col("age").between(min_starting_age, max_starting_age))
        return Cohort(
            cohort.name + "_inconsistent_w_ages_and_age_groups",
            "subjects inconsistent with age groups",
            invalid_subjects,
        )

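    # Hypothetical usage sketch (`driver` and `base_cohort` are illustrative
    # names, not from the source): fail fast when some subjects can never
    # fall into the configured age groups.
    #
    #     invalid = driver._find_subjects_with_age_inconsistent_w_age_groups(
    #         base_cohort)
    #     if invalid.subjects.count() > 0:
    #         raise ValueError("Subjects inconsistent with the age groups.")
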
    def _compute_longitudinal_age_groups(
            self, cohort: Cohort,
            col_offset: int) -> Tuple[DataFrame, List[str]]:
        """
        Parameters
        ----------
        cohort: Cohort
            cohort on which the age groups should be computed
        col_offset: int
            number of columns used by lagged exposure features

        Returns
        -------
        (age_features, mapping): Tuple[DataFrame, List[str]]
            a dataframe containing the age features in aij format and a mapping giving
            the correspondence between column number and age group.
        """
        # This implementation is suboptimal, but we need to have something
        # working with inconsistent python versions across the cluster.
        assert cohort.has_subject_information(), (
            "Cohort subjects should have gender and birthDate information")

        subjects = cohort.subjects.select("patientID", "gender", "birthDate")

        bucket_ids = sf.array([sf.lit(i) for i in range(self.n_buckets)])
        subjects = (subjects.withColumn("bucketID", bucket_ids).select(
            "patientID",
            "gender",
            "birthDate",
            sf.explode("bucketID").alias("bucket"),
        ).withColumn("dateShift",
                     sf.col("bucket") * self.bucket_size).withColumn(
                         "referenceDate", sf.lit(self.age_reference_date)))
        # Longitudinal age is based on referenceDate instead of minDate to
        # be consistent with cohort definition.
        time_references = sf.expr("date_add(referenceDate, dateShift)")
        longitudinal_age = sf.floor(
            sf.months_between(time_references, sf.col("birthDate")) / 12)
        subjects = subjects.withColumn("longitudinalAge",
                                       longitudinal_age).select(
                                           "patientID", "gender", "birthDate",
                                           "bucket", "longitudinalAge")

        subjects, n_age_groups, mapping = self._bucketize_age_column(
            subjects, "longitudinalAge", "longitudinalAgeBucket")

        assert n_age_groups == self.n_age_groups, (
            "Computed number of age groups differs from the number of age"
            " groups specified at initialization. There might be empty age"
            " groups; you should investigate this.")

        age_features = subjects.select(
            sf.col("patientID"),
            sf.col("bucket").alias("rowIndex"),
            (sf.col("longitudinalAgeBucket") + col_offset).alias("colIndex"),
        )

        # Remove "age events" which are not in follow-up
        fup_events = self.followups.intersection(self.final_cohort).events
        fup_events = self._discretize_start_end(fup_events)
        fup_events = rename_df_columns(fup_events, prefix="fup_")
        age_features_columns = age_features.columns
        age_features = age_features.join(fup_events, on="patientID")
        age_features = age_features.where(
            sf.col("rowIndex").between(sf.col("fup_startBucket"),
                                       sf.col("fup_endBucket")))
        age_features = age_features.select(*age_features_columns)

        return age_features, mapping
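
# Illustration (not from the source): consuming the aij output above. Each row
# of age_features is (patientID, rowIndex, colIndex) and `mapping` names the
# age group behind each column index. Under those assumptions, the pairs
# collected for a single patient can be packed into a scipy CSR matrix:
import numpy as np
from scipy.sparse import csr_matrix

def aij_rows_to_csr(rows, n_rows, n_cols):
    """Pack (rowIndex, colIndex) pairs for one patient into a CSR matrix."""
    i = np.array([r[0] for r in rows], dtype=np.int64)
    j = np.array([r[1] for r in rows], dtype=np.int64)
    data = np.ones(len(i), dtype=np.int8)
    return csr_matrix((data, (i, j)), shape=(n_rows, n_cols))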