Amazon Comprehend - Custom Document Classification Example¶
This notebook demonstrates how to build a document classification model with Amazon Comprehend's custom document classification feature. The dataset is the BBC news classification dataset from a Kaggle competition: 2,225 documents in 5 classes, with each document belonging to exactly one class.
Preparation¶
Import the required packages.
Define the AWS credentials, the local paths used to store the data, the S3 folders, the IAM Role, etc.
[44]:
# Standard Library
import typing as T
import os
import json
import random
import zipfile
import dataclasses
# Third Party Library
from rich import print as rprint
from pathlib_mate import Path
from s3pathlib import S3Path, context
from boto_session_manager import BotoSesManager, AwsServiceEnum
import polars as pl
bsm = BotoSesManager(profile_name="aws_data_lab_sanhe_us_east_2")
context.attach_boto_session(bsm.boto_ses)
ch_client = bsm.get_client(AwsServiceEnum.Comprehend)
dir_here = Path(os.getcwd())
dir_docs = dir_here / "docs"
path_bbc_zip = dir_here / "bbc.zip"
dir_bbc = dir_here / "bbc"
path_manifest_csv = dir_here / "manifest.csv"
s3dir_root = S3Path.from_s3_uri("s3://aws-data-lab-sanhe-for-everything-us-east-2/poc/2022-11-05-document-classification-example/").to_dir()
s3path_manifest_csv = s3dir_root / "manifest.csv"
s3dir_test_data = s3dir_root.joinpath("test-data").to_dir()
s3dir_predict = s3dir_root.joinpath("predict").to_dir()
print(f"preview s3dir_root: {s3dir_root.console_url}")
preview s3dir_root: https://console.aws.amazon.com/s3/buckets/aws-data-lab-sanhe-for-everything-us-east-2?prefix=poc/2022-11-05-document-classification-example/
Unzip Dataset¶
Extract the zip archive downloaded from Kaggle, then browse the files locally to get a basic feel for the data.
[9]:
if not dir_bbc.exists():
    with zipfile.ZipFile(path_bbc_zip.abspath) as f:
        f.extractall(path=dir_here.abspath)
Split Train and Test¶
Split the data into training and testing sets at a 7:3 ratio.
[25]:
@dataclasses.dataclass
class Document:
    path: Path = dataclasses.field()
    label: str = dataclasses.field()

doc_list: T.List[Document] = list()
for subdir in dir_bbc.select_dir(recursive=False):
    for path in subdir.select_by_ext(".txt"):
        doc = Document(path=path, label=subdir.basename)
        doc_list.append(doc)

n_doc = len(doc_list)
n_train = int(0.7 * n_doc)
n_test = n_doc - n_train

train_indices: T.Set[int] = set(random.sample(list(range(n_doc)), n_train))
train_doc_list: T.List[Document] = list()
test_doc_list: T.List[Document] = list()
for ind, doc in enumerate(doc_list):
    if ind in train_indices:
        train_doc_list.append(doc)
    else:
        test_doc_list.append(doc)
print(f"total docs = {n_doc}")
print(f"train docs = {n_train}")
print(f"test docs = {n_test}")
total docs = 2225
train docs = 1557
test docs = 668
Convert the Data into the Comprehend Manifest CSV Format¶
Comprehend only accepts this manifest format for training input. Although the documentation also mentions a JSON format, that format assumes the data was labeled with Ground Truth; when building the dataset ourselves, the CSV format is the one to use. In short, each document must become a single CSV row, with all newline characters replaced by spaces. See the documentation for details.
[36]:
rows: T.List[T.Dict[str, T.Any]] = list()
for doc in train_doc_list:
    content = doc.path.read_text(encoding="utf-8", errors="ignore")
    line_content = content.replace("\n", " ").replace("\r", " ")
    rows.append(dict(
        label=doc.label,
        content=line_content,
    ))
df = pl.DataFrame(rows)
df.write_csv(path_manifest_csv.abspath, has_header=False)
s3path_manifest_csv.upload_file(path_manifest_csv.abspath, overwrite=True)
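A row that still contains a newline or has an empty label would break the one-document-per-row contract described above and spoil the training run. The small helper below is my own addition (not part of the Comprehend API) for sanity-checking the `rows` list before uploading:

```python
def validate_manifest_rows(rows) -> list:
    """Return human-readable problems found in manifest rows: embedded
    newlines (which would split a document across CSV rows) and empty
    labels."""
    problems = []
    for i, row in enumerate(rows):
        if "\n" in row["content"] or "\r" in row["content"]:
            problems.append(f"row {i}: content contains a newline")
        if not row["label"]:
            problems.append(f"row {i}: empty label")
    return problems

# problems = validate_manifest_rows(rows)
# assert not problems, problems
```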
Create Document Classifier¶
Then call the API to train a model. Training typically takes about half an hour.
[40]:
classifier_name = "MyClassifier"
classifier_arn = f"arn:aws:comprehend:{bsm.aws_region}:{bsm.aws_account_id}:document-classifier/{classifier_name}"
create_response = ch_client.create_document_classifier(
    InputDataConfig=dict(
        DataFormat="COMPREHEND_CSV",
        S3Uri=s3path_manifest_csv.uri,
    ),
    DataAccessRoleArn="arn:aws:iam::669508176277:role/sanhe-comprehend-admin-access",
    DocumentClassifierName=classifier_name,
    LanguageCode="en",
    Mode="MULTI_CLASS",
)
rprint(create_response)
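`create_document_classifier` returns immediately while training runs in the background. One way to wait for it is a small polling loop over the `describe_document_classifier` API; this is my own sketch, reusing the `ch_client` and `classifier_arn` defined above and treating `TRAINED` / `IN_ERROR` as the terminal states:

```python
import time

def classifier_done(status: str) -> bool:
    """True once a document classifier has reached a terminal state."""
    return status in {"TRAINED", "IN_ERROR"}

def wait_for_classifier(client, arn: str, poll_secs: int = 60) -> str:
    """Poll describe_document_classifier until training finishes,
    then return the final status."""
    while True:
        props = client.describe_document_classifier(
            DocumentClassifierArn=arn,
        )["DocumentClassifierProperties"]
        status = props["Status"]
        print(f"status = {status}")
        if classifier_done(status):
            return status
        time.sleep(poll_secs)

# wait_for_classifier(ch_client, classifier_arn)
```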
[43]:
print("upload test data to s3 ...")
print(f"  preview s3dir_test_data: {s3dir_test_data.console_url}")
for doc in test_doc_list:
    s3path = s3dir_test_data.joinpath(f"{doc.label}-{doc.path.basename}")
    s3path.upload_file(doc.path.abspath)
print("  done")
upload test data to s3 ...
preview s3dir_test_data: https://console.aws.amazon.com/s3/buckets/aws-data-lab-sanhe-for-everything-us-east-2?prefix=poc/2022-11-05-document-classification-example/test-data/
done
Start Document Classification Job¶
Now we can use the trained model to make predictions. The API takes an S3 folder in which each file represents one document.
[45]:
start_response = ch_client.start_document_classification_job(
    InputDataConfig=dict(
        S3Uri=s3dir_test_data.uri,
        InputFormat="ONE_DOC_PER_FILE",
    ),
    OutputDataConfig=dict(
        S3Uri=s3dir_predict.uri,
    ),
    DataAccessRoleArn="arn:aws:iam::669508176277:role/sanhe-comprehend-admin-access",
    DocumentClassifierArn=classifier_arn,
)
print(start_response)
{'JobId': 'c9899916d3ee54ea34a8bd9bb51ba1ca', 'JobArn': 'arn:aws:comprehend:us-east-2:669508176277:document-classification-job/c9899916d3ee54ea34a8bd9bb51ba1ca', 'JobStatus': 'SUBMITTED', 'ResponseMetadata': {'RequestId': '6e0bc2a5-c28a-4dbc-9a76-610627c0f3f2', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '6e0bc2a5-c28a-4dbc-9a76-610627c0f3f2', 'content-type': 'application/x-amz-json-1.1', 'content-length': '182', 'date': 'Sat, 05 Nov 2022 19:10:35 GMT'}, 'RetryAttempts': 0}}
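Once `describe_document_classification_job` reports the job as completed, the results live under the `OutputDataConfig` S3 prefix as a gzipped tar archive. Below is a sketch of unpacking such an archive in memory; the member name `predictions.jsonl` and the exact output key are assumptions for illustration (check the actual archive contents), and fetching the bytes from S3 is left as a comment:

```python
import io
import json
import tarfile

def extract_jsonl_from_targz(targz_bytes: bytes, member: str) -> list:
    """Unpack one member of an in-memory .tar.gz archive and parse it
    as JSON Lines, returning a list of dicts."""
    with tarfile.open(fileobj=io.BytesIO(targz_bytes), mode="r:gz") as tar:
        raw = tar.extractfile(member).read().decode("utf-8")
    return [json.loads(line) for line in raw.strip().split("\n")]

# In the real workflow the bytes would come from the job's output prefix
# under s3dir_predict (locate the output.tar.gz object there), e.g.:
# results = extract_jsonl_from_targz(targz_bytes, "predictions.jsonl")
```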
Evaluate Prediction Result¶
Once the results are ready, we compare them with the original labels and compute the accuracy. The Comprehend output is a gzipped tar archive; unpack it and you get a .jsonl file, i.e. one JSON object per line. The accuracy turns out to be quite good, reaching 98.35%.
[56]:
lines = dir_here.joinpath("predictions.json").read_text().strip().split("\n")
results = [json.loads(line.strip()) for line in lines]
n_test = len(results)
n_good = 0
for record in results:
    file = record["File"]
    true_label = file.split("-")[0]
    predict_label = max(record["Classes"], key=lambda x: x["Score"])["Name"]
    if true_label == predict_label:
        n_good += 1
print(f"accuracy = {n_good / n_test}")
accuracy = 0.9835329341317365
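A single overall accuracy can hide a weak class. Reusing the record shape above (true label as the file-name prefix, predictions under "Classes"), a small per-label breakdown of my own looks like this:

```python
from collections import Counter

def per_class_accuracy(results) -> dict:
    """Break accuracy down per true label, where each record carries
    the true label as the "File" name prefix and the predicted classes
    with scores under "Classes"."""
    total = Counter()
    good = Counter()
    for record in results:
        true_label = record["File"].split("-")[0]
        predicted = max(record["Classes"], key=lambda c: c["Score"])["Name"]
        total[true_label] += 1
        good[true_label] += int(true_label == predicted)
    return {label: good[label] / total[label] for label in total}
```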