Customised plots with plotly graph objects: Example Notebook

Customised plots with plotly graph objects: Example Notebook

# Plotting stack used throughout the notebook.
import pandas as pd
import plotly.express as px  # high-level API: whole figures in one call
import plotly.graph_objects as go  # low-level API: build figures trace by trace
import plotly.io as pio

# Every figure created below uses the dark theme unless it overrides it.
pio.templates.default = "plotly_dark"
# Route DataFrame.plot / Series.plot through plotly instead of the default
# matplotlib backend, so plot methods can be called on DataFrames directly.
pd.options.plotting.backend = "plotly"

Get life expectancy by country, continent, and year

# Gapminder subset: European and American countries only, keeping just the
# columns we plot and renaming lifeExp to a snake_case column name.
all_countries = (
    px.data.gapminder()
    .loc[
        lambda frame: frame["continent"].isin(["Europe", "Americas"]),
        ["country", "continent", "year", "lifeExp"],
    ]
    .rename(columns={"lifeExp": "life_expectancy"})
)
all_countries
country continent year life_expectancy
12 Albania Europe 1952 55.230
13 Albania Europe 1957 59.280
14 Albania Europe 1962 64.820
15 Albania Europe 1967 66.220
16 Albania Europe 1972 67.690
... ... ... ... ...
1639 Venezuela Americas 1987 70.190
1640 Venezuela Americas 1992 71.150
1641 Venezuela Americas 1997 72.146
1642 Venezuela Americas 2002 72.766
1643 Venezuela Americas 2007 73.747

660 rows × 4 columns

Starting point

# Baseline figure: one line per country, one facet column per continent.
fig = px.line(
    all_countries,
    x="year",
    y="life_expectancy",
    color="country",
    facet_col="continent",
    labels={"life_expectancy": "Life expectancy"},
)
fig.update_layout(showlegend=False)
# Facet titles read "continent=Europe"; keep only the part after the last "=".
# for_each_annotation returns the figure, so this last expression displays it.
fig.for_each_annotation(
    lambda facet_title: facet_title.update(text=facet_title.text.rpartition("=")[2])
)

Low-level approach using graph objects

With `plotly.graph_objects`, even simple plots take more code to build, but you can customise every aspect of the figure.

from plotly.subplots import make_subplots

# One subplot column per continent, in this left-to-right order.
continents = ["Europe", "Americas"]
# Countries drawn emphasised, mapped to the 1-based subplot column of their
# continent. Defined once so the "skip" and "highlight" passes stay in sync.
highlighted = {"Poland": 1, "Canada": 2}
color = "red"

# Grey context traces for every non-highlighted country, bucketed by continent.
# groupby(sort=False) keeps countries in order of first appearance and avoids
# building a query string per country, which would break on names with quotes.
traces = {continent: [] for continent in continents}
for name, data in all_countries.groupby("country", sort=False):
    if name in highlighted:
        continue
    traces[data["continent"].iloc[0]].append(
        go.Scatter(
            x=data["year"],
            y=data["life_expectancy"],
            name=name,
            mode="lines",
            line={"color": "darkgray"},
        ),
    )

fig = make_subplots(rows=1, cols=len(continents), subplot_titles=continents)
fig.update_layout(showlegend=False)
for col, continent in enumerate(continents, start=1):
    fig.add_traces(traces[continent], rows=1, cols=col)
# Link the axes of both panels so they share ranges and stay comparable.
fig.update_xaxes(matches="x")
fig.update_yaxes(matches="y")

# Highlighted countries are added last so they are drawn above the grey lines.
for name, col in highlighted.items():
    data = all_countries[all_countries["country"] == name]
    fig.add_trace(
        go.Scatter(
            x=data["year"],
            y=data["life_expectancy"],
            name=name,
            mode="lines",
            line={"color": color, "width": 5},
        ),
        row=1,
        col=col,
    )
    # Bold label at a fixed data coordinate inside the country's subplot.
    fig.add_annotation(
        x=1967,
        y=77,
        text=f"<b>{name}</b>",
        font={"size": 14, "color": color},
        showarrow=False,
        row=1,
        col=col,
    )
fig.show()