Skip to content

eds_scikit.plot.data_quality

plot_age_pyramid

plot_age_pyramid(person: DataFrame, datetime_ref: datetime = None, filename: str = None, savefig: bool = False, return_vector: bool = False) -> Tuple[alt.Chart, Series]

Plot an age pyramid from a 'person' pandas DataFrame.

PARAMETER DESCRIPTION
person

The person table. Must have the following columns: `birth_datetime` (dtype: datetime or str), `person_id` (dtype: any), and `gender_source_value` (dtype: str, one of {'m', 'f'}).

TYPE: pd.DataFrame (ks.DataFrame not supported)

datetime_ref : datetime, The reference date to compute population age from. If set to None, datetime.today() will be used instead.

savefig : bool, If set to True, filename must be set. The plot will be saved at the specified filename.

filename : Optional[str], The path to save figure at.

RETURNS DESCRIPTION
chart

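The age pyramid chart. Note: when both savefig and return_vector are set to True, the chart is saved to filename and only group_gender_age is returned.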

TYPE: alt.Chart

group_gender_age : Series, The total number of patients grouped by gender and binned age.

Source code in eds_scikit/plot/data_quality.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def plot_age_pyramid(
    person: DataFrame,
    datetime_ref: datetime = None,
    filename: str = None,
    savefig: bool = False,
    return_vector: bool = False,
) -> Tuple[alt.Chart, Series]:
    """Plot an age pyramid from a 'person' pandas DataFrame.

    Ages are computed relative to ``datetime_ref`` (using 365-day years),
    rows with a non-positive age are dropped, and only rows whose
    ``gender_source_value`` is ``'m'`` or ``'f'`` are counted. Ages are
    grouped in 10-year bins from ``0-10`` to ``80-90``; any value left
    unbinned by the cut is labelled ``90+``.

    Parameters
    ----------
    person : pd.DataFrame (ks.DataFrame not supported),
        The person table. Must have the following columns:
        - `birth_datetime`, dtype : datetime or str
        - `person_id`, dtype : any
        - `gender_source_value`, dtype : str, {'m', 'f'}

        NOTE(review): `birth_datetime` is subtracted from the reference
        date without any conversion -- confirm that str values are
        actually handled upstream.

    datetime_ref : datetime,
        The reference date to compute population age from.
        If set to None, datetime.today() will be used instead.

    filename : Optional[str],
        The path to save figure at. Only used when `savefig` is True.

    savefig : bool,
        If set to True, filename must be set.
        The plot will be saved at the specified filename.

    return_vector : bool,
        If set to True, the aggregated counts are returned as well
        (see Returns).

    Returns
    -------
    chart : alt.Chart,
        The concatenated chart (male bars, age labels, female bars).
        Returned alone when `return_vector` is False, and as the first
        element of a ``(chart, group_gender_age)`` tuple when
        `return_vector` is True and `savefig` is False.
        It is NOT returned when both `savefig` and `return_vector` are True.

    group_gender_age : Series,
        The total number of patients grouped by gender and binned age.
        Returned alone when both `savefig` and `return_vector` are True.

    Raises
    ------
    ValueError
        If `savefig` is True and `filename` is None or is not a str.
    """
    # Presumably raises when a required column is missing -- the helper is
    # defined outside this block.
    check_columns(person, ["person_id", "birth_datetime", "gender_source_value"])

    if savefig:
        if filename is None:
            raise ValueError("You have to set a filename")
        if not isinstance(filename, str):
            raise ValueError(f"'filename' type must be str, got {type(filename)}")

    # Work on a copy so the caller's table is never modified.
    person_ = person.copy()

    if datetime_ref is None:
        today = datetime.today()
    else:
        today = pd.to_datetime(datetime_ref)

    # TODO: replace with from ..utils.datetime_helpers.substract_datetime
    deltas = today - person_["birth_datetime"]
    if bd.is_pandas(person_):
        # pandas yields timedeltas here; convert them to seconds.
        deltas = deltas.dt.total_seconds()

    # Age in years, approximated with 365-day years.
    person_["age"] = deltas / (365 * 24 * 3600)
    person_ = person_.query("age > 0.0")

    # Bin edges 0, 10, ..., 90 -> labels "0-10", ..., "80-90".
    bins = np.arange(0, 100, 10)
    labels = [f"{left}-{right}" for left, right in zip(bins[:-1], bins[1:])]
    person_["age_bins"] = bd.cut(person_["age"], bins=bins, labels=labels)

    # Unbinned values show up as the string "nan" after the cast; they are
    # relabelled as the open-ended "90+" bucket.
    person_["age_bins"] = (
        person_["age_bins"].astype(str).str.lower().str.replace("nan", "90+")
    )

    person_ = person_.loc[person_["gender_source_value"].isin(["m", "f"])]
    group_gender_age = person_.groupby(["gender_source_value", "age_bins"])[
        "person_id"
    ].count()

    # Convert to pandas to ease plotting.
    # Since we have aggregated the data, this operation won't crash.
    group_gender_age = bd.to_pandas(group_gender_age)

    male = group_gender_age["m"].reset_index()
    female = group_gender_age["f"].reset_index()

    # Male bars grow towards the left (descending x axis).
    left = (
        alt.Chart(male)
        .mark_bar()
        .encode(
            y=alt.Y("age_bins", axis=None, sort=alt.SortOrder("descending")),
            x=alt.X("person_id", sort=alt.SortOrder("descending")),
        )
        .properties(title="Male")
    )

    right = (
        alt.Chart(female)
        .mark_bar(color="coral")
        .encode(
            y=alt.Y("age_bins", axis=None, sort=alt.SortOrder("descending")),
            x=alt.X("person_id", title="N"),
        )
        .properties(title="Female")
    )

    # Central column of age-bin labels. The position channel (with its
    # axis/sort options) goes through alt.Y and the label through alt.Text,
    # consistently with the two bar charts above.
    middle = (
        alt.Chart(male)
        .mark_text()
        .encode(
            y=alt.Y("age_bins", axis=None, sort=alt.SortOrder("descending")),
            text=alt.Text("age_bins"),
        )
    )

    chart = alt.concat(left, middle, right, spacing=5)

    if savefig:
        chart.save(filename)
        if return_vector:
            # Historical behaviour: when saving, only the counts are returned.
            return group_gender_age

    if return_vector:
        return chart, group_gender_age

    return chart