AWS Athena with Python

Keywords: Athena, Python, pyathena

In the Python world there is a standard called DB API 2.0. No matter which SQL database sits underneath (as long as it is a SQL database), any library that follows this standard lets you open a connection, obtain a cursor, call cursor.execute(sql_statement), and iterate over the cursor to retrieve the result rows, where each row comes back as a tuple. The production-grade Python SQL toolkit sqlalchemy also has good support for any library that implements DB API 2.0.
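The connect / cursor / execute pattern described above can be sketched with the stdlib sqlite3 driver, which is itself a DB API 2.0 implementation (any compliant driver, pyathena included, follows the same shape):

```python
# Minimal DB API 2.0 usage: connect, execute, iterate over the cursor.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (id INTEGER, value INTEGER)")
conn.executemany("INSERT INTO events VALUES (?, ?)", [(1, 10), (2, 20)])

# the cursor is iterable; each row is a plain tuple
cursor = conn.execute("SELECT id, value FROM events ORDER BY id")
rows = list(cursor)
print(rows)  # [(1, 10), (2, 20)]
```

Swapping sqlite3 for pyathena changes only the connect() call; the cursor semantics stay the same.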

The Python community's DB API 2.0 implementation for Athena is https://pypi.org/project/pyathena/. Under the hood, Athena stores its data in S3 buckets, and pyathena implements a wrapper around the Athena API so that it satisfies the DB API 2.0 standard. If you want to work with Athena from Python, the best experience is to follow the pyathena documentation and use it together with sqlalchemy and pandas.
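For the sqlalchemy route, pyathena registers an awsathena+rest dialect whose connection URL embeds the region, schema, and staging directory. A sketch of assembling such a URL (the region, schema, and staging dir here are placeholder values; with no credentials in the URL, they are resolved from the environment or a named profile):

```python
# Hedged sketch: building a pyathena SQLAlchemy URL for the awsathena+rest
# dialect. The staging dir must be percent-encoded inside the query string.
from urllib.parse import quote_plus

region = "us-east-2"
schema = "poc"
s3_staging_dir = "s3://aws-data-lab-sanhe-for-everything-us-east-2/athena/results/"

url = (
    f"awsathena+rest://@athena.{region}.amazonaws.com:443/"
    f"{schema}?s3_staging_dir={quote_plus(s3_staging_dir)}"
)
print(url)
# then: engine = sqlalchemy.create_engine(url)
```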

requirements.txt Dependency

# test dependencies for pyathena
PyAthena[Pandas,SQLAlchemy]
pandas
sqlalchemy
s3pathlib
smart_open

prepare_data.py data faker

# -*- coding: utf-8 -*-

"""

"""

import numpy as np
import pandas as pd
from s3pathlib import S3Path

# define s3 object path
class Config:
    bucket = "aws-data-lab-sanhe-for-everything-us-east-2"
    prefix = "poc/2022-02-04-aws-athena-with-python"

s3path = S3Path(
    Config.bucket,
    Config.prefix,
    "events",
    "data.csv",
)
print(f"preview data at {s3path.console_url}")

# generate dummy data
n_rows = 1000
df = pd.DataFrame()
df["id"] = range(1, n_rows+1)
df["time"] = pd.date_range(start="2000-01-01", end="2000-03-31", periods=n_rows)
df["category"] = np.random.randint(1, 1+3, size=n_rows)
df["value"] = np.random.randint(1, 1+100, size=n_rows)

# write csv to s3
with s3path.open("w") as f:
    df.to_csv(f, index=False)

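The query script below assumes an Athena table poc.events already exists over the uploaded CSV. A hedged sketch of what that one-time DDL might look like (the SerDe choice and column types are assumptions, not taken from the original setup; OpenCSVSerde in particular treats columns loosely, so adjust to taste; the DDL could be executed once through a pyathena cursor):

```python
# Hedged sketch: DDL registering the uploaded CSV as poc.events.
# Bucket and prefix mirror the Config class in prepare_data.py; LOCATION
# points at the folder that contains data.csv, not the file itself.
bucket = "aws-data-lab-sanhe-for-everything-us-east-2"
prefix = "poc/2022-02-04-aws-athena-with-python"

create_table_ddl = f"""
CREATE EXTERNAL TABLE IF NOT EXISTS poc.events (
    id INT,
    time TIMESTAMP,
    category INT,
    value INT
)
ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'
LOCATION 's3://{bucket}/{prefix}/events/'
TBLPROPERTIES ('skip.header.line.count' = '1')
""".strip()
print(create_table_ddl)
```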
query_data.py pyathena usage example

# -*- coding: utf-8 -*-

from s3pathlib import S3Path
from pyathena import connect
import pandas as pd

# define the s3 folder to store query results
s3path_athena_result = S3Path(
    "aws-data-lab-sanhe-for-everything-us-east-2",
    "athena/results/"
)

# define connection, use AWS CLI named profile for authentication
conn = connect(
    s3_staging_dir=s3path_athena_result.uri,
    profile_name="aws_data_lab_sanhe",
    region_name="us-east-2",
)

# define the SQL statement, use ${database}.${table} as t to specify the table
sql = """
SELECT 
    t.category,
    AVG(t.value) as average_value  
FROM poc.events t
GROUP BY t.category
ORDER BY t.category
"""

# execute the SQL query, load result to pandas DataFrame
df = pd.read_sql_query(sql, conn)
print(df)
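The aggregation that Athena runs above can be sanity-checked locally by generating dummy data of the same shape and running the equivalent pandas groupby (the values are random, so the averages will differ from any particular Athena run):

```python
# Local equivalent of the Athena query: average value per category.
import numpy as np
import pandas as pd

n_rows = 1000
df = pd.DataFrame()
df["category"] = np.random.randint(1, 1 + 3, size=n_rows)
df["value"] = np.random.randint(1, 1 + 100, size=n_rows)

# mirrors: SELECT category, AVG(value) ... GROUP BY category ORDER BY category
local_result = (
    df.groupby("category", as_index=False)["value"]
    .mean()
    .rename(columns={"value": "average_value"})
    .sort_values("category")
)
print(local_result)
```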