Skip to content

eds_scikit.plot.age_pyramid

plot_age_pyramid

plot_age_pyramid(person: DataFrame, datetime_ref: datetime = None, return_array: bool = False) -> Tuple[alt.ConcatChart, Series]

Plot an age pyramid from a 'person' pandas DataFrame.

PARAMETER DESCRIPTION
person

The person table. Must have the following columns:

- `birth_datetime`, dtype: datetime or str
- `person_id`, dtype: any
- `gender_source_value`, dtype: str, {'m', 'f'}

TYPE: pd.DataFrame (ks.DataFrame not supported)

datetime_ref : Union[datetime, str], default None — The reference date to compute population age from. If a string matching a column of the person table, that column is used: each patient has their own datetime reference. If any other string, it is parsed as a date (an unparseable string raises a ValueError). If a datetime, the reference datetime is the same for all patients. If set to None, datetime.today() is used instead.

filename : str, default None — The path to save the figure at.

savefig : bool, default False — If set to True, `filename` must be set; the plot is saved at the specified filename.

return_array : bool, default False — If set to True, return the aggregated counts (`group_gender_age`, a pandas Series) instead of the chart.

RETURNS DESCRIPTION
chart

If `savefig` is set to True, returns None.

TYPE: alt.ConcatChart

group_gender_age : Series — The total number of patients grouped by gender and binned age.

Source code in eds_scikit/plot/age_pyramid.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def plot_age_pyramid(
    person: DataFrame,
    datetime_ref: datetime = None,
    return_array: bool = False,
    *,
    filename: str = None,
    savefig: bool = False,
) -> Tuple[alt.ConcatChart, Series]:
    """Plot an age pyramid from a 'person' pandas DataFrame.

    Patients are binned by age in 10-year bins (ages above 90 are grouped in a
    single ``"90+"`` bin) and split by gender; male counts are drawn on the
    left and female counts on the right. Patients whose age falls outside the
    (0, 125) interval are skipped with a warning.

    Parameters
    ----------
    person : pd.DataFrame (ks.DataFrame not supported)
        The person table. Must have the following columns:
        - `birth_datetime`, dtype : datetime or str
        - `person_id`, dtype : any
        - `gender_source_value`, dtype : str, {'m', 'f'}
        Rows whose gender is neither 'm' nor 'f' are ignored.

    datetime_ref : Union[datetime, str], default None
        The reference date to compute population age from.
        If a string matching a column of the person table, that column is used:
        each patient has their own datetime reference.
        If any other string, it is parsed as a date shared by all patients.
        If a datetime, the reference datetime is the same for all patients.
        If set to None, datetime.today() will be used instead.

    return_array : bool, default False
        If set to True, return the aggregated counts ``group_gender_age``
        instead of the chart.

    filename : str, default None
        Keyword-only. The path to save the figure at.

    savefig : bool, default False
        Keyword-only. If set to True, `filename` must be set and the plot is
        saved at the specified filename.

    Returns
    -------
    chart : alt.ConcatChart
        Returned when `return_array` is False. If `savefig` is True, None is
        returned instead.

    group_gender_age : Series
        Returned (instead of the chart) when `return_array` is True.
        The total number of patients grouped by gender and binned age.

    Raises
    ------
    ValueError
        If `savefig` is True without a `filename`, or if `datetime_ref` is a
        string that is neither a column name nor a parseable date.
    TypeError
        If `datetime_ref` has an unsupported type.
    """
    check_columns(person, ["person_id", "birth_datetime", "gender_source_value"])

    # Fail fast, before any heavy computation, on an unusable save request.
    if savefig and filename is None:
        raise ValueError("`filename` must be set when `savefig` is True.")

    datetime_ref_raw = copy(datetime_ref)
    # Name of a per-patient reference column, when `datetime_ref` designates one.
    datetime_ref_col = None

    if datetime_ref is None:
        datetime_ref = datetime.today()
    elif isinstance(datetime_ref, datetime):
        datetime_ref = pd.to_datetime(datetime_ref)
    elif isinstance(datetime_ref, str):
        # A string type for datetime_ref could be either
        # a column name or a datetime in string format.
        if datetime_ref in person.columns:
            datetime_ref_col = datetime_ref
        else:
            datetime_ref = pd.to_datetime(
                datetime_ref, errors="coerce"
            )  # In case of error, will return NaT
            if pd.isnull(datetime_ref):
                raise ValueError(
                    f"`datetime_ref` must either be a column name or parseable date, "
                    f"got string '{datetime_ref_raw}'"
                )
    else:
        raise TypeError(
            f"`datetime_ref` must be either None, a parseable string date"
            f", a column name or a datetime. Got type: {type(datetime_ref)}, {datetime_ref}"
        )

    cols_to_keep = ["person_id", "birth_datetime", "gender_source_value"]
    # Select the reference column together with the others so that it goes
    # through the same pandas conversion and shares the same index.
    if datetime_ref_col is not None and datetime_ref_col not in cols_to_keep:
        cols_to_keep.append(datetime_ref_col)
    person_ = bd.to_pandas(person[cols_to_keep])

    # `birth_datetime` (and a reference column) may be stored as strings.
    birth_datetime = pd.to_datetime(person_["birth_datetime"])
    if datetime_ref_col is not None:
        datetime_ref = pd.to_datetime(person_[datetime_ref_col])

    # 365.25 days per year accounts for leap years on average.
    seconds_per_year = 365.25 * 24 * 3600
    person_["age"] = (datetime_ref - birth_datetime).dt.total_seconds()
    person_["age"] /= seconds_per_year

    # Remove outliers
    mask_age_inliers = (person_["age"] > 0) & (person_["age"] < 125)
    n_outliers = int((~mask_age_inliers).sum())
    if n_outliers > 0:
        perc_outliers = 100 * n_outliers / person_.shape[0]
        logger.warning(
            f"{n_outliers} ({perc_outliers:.4f}%) individuals' "
            "age is out of the (0, 125) interval, we skip them."
        )
    # `.copy()` so the `.loc` assignments below act on an independent frame.
    person_ = person_.loc[mask_age_inliers].copy()

    # Aggregate rare age categories: every age above 90 lands in the last bin.
    mask_rare_age_agg = person_["age"] > 90
    person_.loc[mask_rare_age_agg, "age"] = 99

    # Edges 0, 10, ..., 100 -> 10 right-closed bins; the last one, (90, 100],
    # holds the aggregated ages (previously it was missing and those patients
    # were mapped to NaN by `pd.cut`).
    bins = np.arange(0, 101, 10)
    labels = [f"{left}-{right}" for left, right in zip(bins[:-2], bins[1:-1])]
    labels.append(f"{bins[-2]}+")
    person_["age_bins"] = pd.cut(person_["age"], bins=bins, labels=labels)

    person_ = person_.loc[person_["gender_source_value"].isin(["m", "f"])].copy()
    # Categorical gender + observed=False guarantees both genders and every age
    # bin appear in the result (with 0 counts), so the lookups below never fail.
    person_["gender_source_value"] = person_["gender_source_value"].astype(
        pd.CategoricalDtype(["f", "m"])
    )
    group_gender_age = person_.groupby(
        ["gender_source_value", "age_bins"], observed=False
    )["person_id"].count()

    male = group_gender_age["m"].reset_index()
    female = group_gender_age["f"].reset_index()

    left = (
        alt.Chart(male)
        .mark_bar()
        .encode(
            y=alt.Y("age_bins", axis=None, sort=alt.SortOrder("descending")),
            x=alt.X("person_id", sort=alt.SortOrder("descending")),
        )
        .properties(title="Male")
    )

    right = (
        alt.Chart(female)
        .mark_bar(color="coral")
        .encode(
            y=alt.Y("age_bins", axis=None, sort=alt.SortOrder("descending")),
            x=alt.X("person_id", title="N"),
        )
        .properties(title="Female")
    )

    # Central column showing the age-bin labels.
    middle = (
        alt.Chart(male)
        .mark_text()
        .encode(
            y=alt.Y("age_bins", axis=None, sort=alt.SortOrder("descending")),
            text=alt.Text("age_bins"),
        )
    )

    chart = alt.concat(left, middle, right, spacing=5)

    if savefig:
        chart.save(filename)

    if return_array:
        # NOTE(review): despite the Tuple return annotation, only the counts
        # are returned here; kept as-is for backward compatibility.
        return group_gender_age

    return None if savefig else chart
Back to top