from great_tables import GT
from great_tables.data import countrypops
import polars as pl
import polars.selectors as cs
# Two-letter country codes for each sub-region of Oceania, keyed by region name.
oceania = {
    "Australasia": "AU NZ".split(),
    "Melanesia": "NC PG SB VU".split(),
    "Micronesia": "FM GU KI MH MP NR PW".split(),
    "Polynesia": "PF WS TO TV".split(),
}
# Inverted lookup: country code -> region name (e.g. "AU" -> "Australasia").
country_to_region = {
    code: region_name
    for region_name, member_codes in oceania.items()
    for code in member_codes
}
# Build a wide table of Oceania populations: one row per country (with its
# region), one column per selected year, largest 2020 population first.
wide_pops = (
    pl.from_pandas(countrypops)
    # Keep only the countries present in the region lookup ...
    .filter(pl.col("country_code_2").is_in(list(country_to_region)))
    # ... and only the three years of interest.
    .filter(pl.col("year").is_in([2000, 2010, 2020]))
    # Translate each country code into its region name.
    .with_columns(region=pl.col("country_code_2").replace(country_to_region))
    # NOTE(review): newer polars releases appear to prefer `on=` over
    # `columns=` for pivot -- confirm against the pinned polars version.
    .pivot(
        index=["country_name", "region"],
        columns="year",
        values="population",
    )
    .sort("2020", descending=True)
)
# Render the wide table: country names become the row labels and rows are
# grouped under their region. Left as a bare expression so that a notebook
# displays the resulting table.
(
    GT(
        wide_pops,
        rowname_col="country_name",
        groupname_col="region",
    )
    .tab_header(
        title="Populations of Oceania's Countries in 2000, 2010, and 2020",
    )
    # One spanner label over every column matched by the selector.
    .tab_spanner(
        label="Total Population",
        columns=cs.all(),
    )
    # No columns given, so the formatter's default column selection applies.
    .fmt_integer()
)