Amazon Comprehend - Custom Document Classification Example

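This notebook demonstrates how to build a document classification model with the custom document classification feature of Amazon Comprehend. The dataset is the BBC news classification dataset from a Kaggle competition. It contains 2225 documents in 5 classes, and each document belongs to exactly one class.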

Preparation

  • Import the required packages.

  • Define the AWS credentials, the local paths for storing data, the S3 folders, the IAM role, and so on.

[44]:
# Standard Library
import typing as T
import os
import json
import random
import zipfile
import dataclasses


# Third Party Library
from rich import print as rprint
from pathlib_mate import Path
from s3pathlib import S3Path, context
from boto_session_manager import BotoSesManager, AwsServiceEnum

import polars as pl

bsm = BotoSesManager(profile_name="aws_data_lab_sanhe_us_east_2")
context.attach_boto_session(bsm.boto_ses)

ch_client = bsm.get_client(AwsServiceEnum.Comprehend)

dir_here = Path(os.getcwd())
dir_docs = dir_here / "docs"
path_bbc_zip = dir_here / "bbc.zip"
dir_bbc = dir_here / "bbc"
path_manifest_csv = dir_here / "manifest.csv"

s3dir_root = S3Path.from_s3_uri("s3://aws-data-lab-sanhe-for-everything-us-east-2/poc/2022-11-05-document-classification-example/").to_dir()
s3path_manifest_csv = s3dir_root / "manifest.csv"
s3dir_test_data = s3dir_root.joinpath("test-data").to_dir()
s3dir_predict = s3dir_root.joinpath("predict").to_dir()
print(f"preview s3dir_root: {s3dir_root.console_url}")
preview s3dir_root: https://console.aws.amazon.com/s3/buckets/aws-data-lab-sanhe-for-everything-us-east-2?prefix=poc/2022-11-05-document-classification-example/

Unzip Dataset

Unzip the archive downloaded from Kaggle to the local disk, then browse the files locally to get a basic idea of the data.

[9]:
if not dir_bbc.exists():
    with zipfile.ZipFile(path_bbc_zip.abspath) as f:
        f.extractall(path=dir_here.abspath)
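
One quick way to get that basic idea of the data is to count the documents in each class folder. The sketch below is a minimal illustration using only the standard library `pathlib` (not `pathlib_mate`). The helper name `count_docs_per_label` and the demo folder contents are hypothetical, and it assumes one subfolder per class with `.txt` files inside, as in the unzipped dataset.

```python
import pathlib
import tempfile
from collections import Counter


def count_docs_per_label(root: pathlib.Path) -> Counter:
    """Count the .txt files in each immediate subfolder of ``root``."""
    counts = Counter()
    for sub in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[sub.name] = sum(1 for _ in sub.glob("*.txt"))
    return counts


# Tiny stand-in for the unzipped dataset (hypothetical file names),
# so the sketch runs without the real archive.
with tempfile.TemporaryDirectory() as tmp:
    demo_root = pathlib.Path(tmp)
    for label, n_files in [("business", 2), ("sport", 3)]:
        (demo_root / label).mkdir()
        for i in range(n_files):
            (demo_root / label / f"{i:03d}.txt").write_text("demo", encoding="utf-8")
    demo_counts = count_docs_per_label(demo_root)

print(dict(demo_counts))
```

In the notebook, the same helper could presumably be pointed at the `bbc` folder, for example `count_docs_per_label(pathlib.Path(dir_bbc.abspath))`.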

Split Train and Test

Split the data into a training set and a testing set at a 7:3 ratio.

[25]:
@dataclasses.dataclass
class Document:
    path: Path = dataclasses.field()
    label: str = dataclasses.field()


doc_list: T.List[Document] = list()
# Each subfolder name is the class label of the documents inside it.
for dir_label in dir_bbc.select_dir(recursive=False):
    for path in dir_label.select_by_ext(".txt"):
        doc = Document(path=path, label=dir_label.basename)
        doc_list.append(doc)

n_doc = len(doc_list)
n_train = int(0.7 * n_doc)
n_test = n_doc - n_train

train_indices: T.Set[int] = set(random.sample(range(n_doc), n_train))
train_doc_list: T.List[Document] = list()
test_doc_list: T.List[Document] = list()
for ind, doc in enumerate(doc_list):
    if ind in train_indices:
        train_doc_list.append(doc)
    else:
        test_doc_list.append(doc)

print(f"total docs = {n_doc}")
print(f"train docs = {n_train}")
print(f"test docs = {n_test}")
total docs = 2225
train docs = 1557
test docs = 668
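
The split above is random and unseeded, so each run produces a different partition. If a repeatable split is wanted, one option is to draw the training indices from a seeded generator. This is a sketch of that variation, not what the notebook ran; the function name `split_train_test` and the seed value are assumptions made for illustration.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

Item = TypeVar("Item")


def split_train_test(
    items: Sequence[Item],
    train_ratio: float = 0.7,
    seed: int = 0,
) -> Tuple[List[Item], List[Item]]:
    """Split ``items`` into (train, test), preserving the original order."""
    rng = random.Random(seed)  # private generator, leaves the global state alone
    n_train = int(train_ratio * len(items))
    train_idx = set(rng.sample(range(len(items)), n_train))
    train = [x for i, x in enumerate(items) if i in train_idx]
    test = [x for i, x in enumerate(items) if i not in train_idx]
    return train, test


# Integers stand in for the 2225 documents.
demo_train, demo_test = split_train_test(list(range(2225)), train_ratio=0.7, seed=42)
print(len(demo_train), len(demo_test))  # 1557 668
```

Calling the function again with the same seed returns the same partition, which makes accuracy figures comparable across runs.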

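Convert the Data to the Comprehend Manifest CSV Format

Comprehend only accepts training data in its manifest format. The documentation also describes a JSON format, but that format assumes the data was labeled with Ground Truth. Since we build the training data ourselves, we use the CSV format. In short, each document becomes a single line: the label comes first, followed by the text with all line breaks replaced by spaces. See the Comprehend documentation for details.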

[36]:
rows: T.List[T.Dict[str, T.Any]] = list()
for doc in train_doc_list:
    content = doc.path.read_text(encoding="utf-8", errors="ignore")
    line_content = content.replace("\n", " ").replace("\r", " ")
    rows.append(dict(
        label=doc.label,
        content=line_content,
    ))
df = pl.DataFrame(rows)
df.write_csv(path_manifest_csv.abspath, has_header=False)
s3path_manifest_csv.upload_file(path_manifest_csv.abspath, overwrite=True)
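
To make the row layout concrete, here is a small standard-library sketch of the same idea: label in the first column, single-line text in the second, no header. The helper names `to_manifest_row` and `write_manifest` and the sample texts are hypothetical. Unlike the cell above, this variant collapses every run of whitespace into one space, which is a slightly different normalization.

```python
import csv
import io


def to_manifest_row(label: str, text: str) -> list:
    """Return [label, single-line text] with whitespace runs collapsed."""
    return [label, " ".join(text.split())]


def write_manifest(rows, fileobj) -> None:
    """Write rows as header-less CSV; csv handles quoting of commas and quotes."""
    writer = csv.writer(fileobj, lineterminator="\n")
    writer.writerows(rows)


demo_rows = [
    to_manifest_row("tech", 'Broadband "boom"\nin the UK,\r\nsays report'),
    to_manifest_row("sport", "Late goal\nwins the cup"),
]
buffer = io.StringIO()
write_manifest(demo_rows, buffer)
manifest_text = buffer.getvalue()
print(manifest_text)
```

The first demo row contains both a comma and double quotes, so the writer wraps the text field in quotes and doubles the embedded quote characters. The row still occupies exactly one line.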

Create Document Classifier

Next, call the API to train a model. Training usually takes about half an hour.

[40]:
classifier_name = "MyClassifier"
classifier_arn = f"arn:aws:comprehend:{bsm.aws_region}:{bsm.aws_account_id}:document-classifier/{classifier_name}"
create_response = ch_client.create_document_classifier(
    InputDataConfig=dict(
        DataFormat="COMPREHEND_CSV",
        S3Uri=s3path_manifest_csv.uri,
    ),
    DataAccessRoleArn="arn:aws:iam::669508176277:role/sanhe-comprehend-admin-access",
    DocumentClassifierName=classifier_name,
    LanguageCode="en",
    Mode="MULTI_CLASS",
)
rprint(create_response)
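
Because training takes a while, it can help to poll the classifier status before starting a prediction job. The sketch below assumes the boto3 Comprehend method `describe_document_classifier`, whose response carries the status under `DocumentClassifierProperties.Status`. The helper `wait_for_classifier`, its parameters, and the `FakeComprehendClient` stand-in (used so the example runs offline) are hypothetical, and the set of terminal statuses is my reading of the API rather than something the notebook confirms.

```python
import time


def wait_for_classifier(client, classifier_arn, delay=60, max_attempts=60, sleep=time.sleep):
    """Poll until the classifier is trained, fails, or attempts run out."""
    for _ in range(max_attempts):
        response = client.describe_document_classifier(
            DocumentClassifierArn=classifier_arn,
        )
        status = response["DocumentClassifierProperties"]["Status"]
        if status.startswith("TRAINED"):
            return status
        if status in ("IN_ERROR", "STOPPED"):
            raise RuntimeError(f"classifier ended in status {status}")
        sleep(delay)
    raise TimeoutError(f"classifier not ready after {max_attempts} attempts")


class FakeComprehendClient:
    """Offline stand-in that replays a fixed sequence of statuses."""

    def __init__(self, statuses):
        self._statuses = iter(statuses)

    def describe_document_classifier(self, DocumentClassifierArn):
        return {"DocumentClassifierProperties": {"Status": next(self._statuses)}}


demo_client = FakeComprehendClient(["SUBMITTED", "TRAINING", "TRAINED"])
demo_status = wait_for_classifier(demo_client, "demo-classifier-arn", delay=0)
print(demo_status)
```

With the real client, the call would look like `wait_for_classifier(ch_client, classifier_arn)`, using the default one-minute delay between polls.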
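
Upload Test Data

Upload the test documents to S3. Each file name is prefixed with the document's true label so that the label can be recovered later when evaluating the predictions.

[43]: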
print("upload test data to s3 ...")
print(f"  preview s3dir_test_data: {s3dir_test_data.console_url}")
for doc in test_doc_list:
    s3path = s3dir_test_data.joinpath(f"{doc.label}-{doc.path.basename}")
    s3path.upload_file(doc.path.abspath)
print("  done")
upload test data to s3 ...
  preview s3dir_test_data: https://console.aws.amazon.com/s3/buckets/aws-data-lab-sanhe-for-everything-us-east-2?prefix=poc/2022-11-05-document-classification-example/test-data/
  done

Start Document Classification Job

Now we can use the trained model to make predictions. The API takes an S3 folder as input, in which each file represents one document.

[45]:
start_response = ch_client.start_document_classification_job(
    InputDataConfig=dict(
        S3Uri=s3dir_test_data.uri,
        InputFormat="ONE_DOC_PER_FILE",
    ),
    OutputDataConfig=dict(
        S3Uri=s3dir_predict.uri,
    ),
    DataAccessRoleArn="arn:aws:iam::669508176277:role/sanhe-comprehend-admin-access",
    DocumentClassifierArn=classifier_arn,
)
print(start_response)
{'JobId': 'c9899916d3ee54ea34a8bd9bb51ba1ca', 'JobArn': 'arn:aws:comprehend:us-east-2:669508176277:document-classification-job/c9899916d3ee54ea34a8bd9bb51ba1ca', 'JobStatus': 'SUBMITTED', 'ResponseMetadata': {'RequestId': '6e0bc2a5-c28a-4dbc-9a76-610627c0f3f2', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '6e0bc2a5-c28a-4dbc-9a76-610627c0f3f2', 'content-type': 'application/x-amz-json-1.1', 'content-length': '182', 'date': 'Sat, 05 Nov 2022 19:10:35 GMT'}, 'RetryAttempts': 0}}

Evaluate Prediction Result

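Once the results are ready, we compare them with the original labels and compute the accuracy. Comprehend writes its output as a .tar archive. After unpacking it, you will find a .jsonl file, in which each line is one JSON object. The accuracy turns out to be quite good, reaching 98.35%.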
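
The unpacking step can also be done in Python instead of by hand. The following is a minimal sketch using the standard library `tarfile` module. It assumes the archive has already been downloaded locally and that every regular file inside it is JSON Lines. The helper name `read_predictions`, the archive and member names, and the demo records are illustrative assumptions.

```python
import io
import json
import os
import tarfile
import tempfile


def read_predictions(archive_path: str) -> list:
    """Parse every JSON Lines member of a tar archive into a list of dicts."""
    records = []
    # "r:*" lets tarfile detect the compression (plain tar or gzip).
    with tarfile.open(archive_path, mode="r:*") as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            raw = tar.extractfile(member).read().decode("utf-8")
            records.extend(json.loads(line) for line in raw.splitlines() if line.strip())
    return records


# Build a small archive that mimics the output layout (hypothetical content).
demo_records = [
    {"File": "sport-001.txt", "Classes": [{"Name": "sport", "Score": 0.99}]},
    {"File": "tech-002.txt", "Classes": [{"Name": "tech", "Score": 0.87}]},
]
with tempfile.TemporaryDirectory() as tmp:
    archive_path = os.path.join(tmp, "output.tar.gz")
    payload = "\n".join(json.dumps(r) for r in demo_records).encode("utf-8")
    with tarfile.open(archive_path, mode="w:gz") as tar:
        info = tarfile.TarInfo(name="predictions.jsonl")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
    parsed = read_predictions(archive_path)

print(len(parsed), parsed[0]["File"])
```

Reading the members directly from the archive avoids leaving an extracted copy on disk, at the cost of holding all records in memory.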

[56]:
# Each line of the extracted predictions file is one JSON record.
lines = dir_here.joinpath("predictions.json").read_text().strip().split("\n")
results = [
    json.loads(line.strip())
    for line in lines
]
n_test = len(results)
n_good = 0
for record in results:
    file = record["File"]
    true_label = file.split("-")[0]
    # The predicted label is the class with the highest score.
    predict_label = max(record["Classes"], key=lambda x: x["Score"])["Name"]
    if true_label == predict_label:
        n_good += 1
print(f"accuracy = {n_good / n_test}")
accuracy = 0.9835329341317365
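
A single overall number can hide differences between classes, so a per-class breakdown is a natural follow-up. This sketch reuses the conventions of the evaluation above: the true label is the part of `File` before the first hyphen, and the predicted label is the highest-scoring entry in `Classes`. The helper names `top_label` and `per_class_accuracy` and the demo records are hypothetical.

```python
from collections import Counter


def top_label(record: dict) -> str:
    """Return the class name with the highest score."""
    return max(record["Classes"], key=lambda c: c["Score"])["Name"]


def per_class_accuracy(records: list) -> dict:
    """Map each true label to the fraction of its documents predicted correctly."""
    total = Counter()
    correct = Counter()
    for record in records:
        true_label = record["File"].split("-", 1)[0]
        total[true_label] += 1
        if top_label(record) == true_label:
            correct[true_label] += 1
    return {label: correct[label] / total[label] for label in sorted(total)}


demo_results = [
    {"File": "sport-001.txt", "Classes": [{"Name": "sport", "Score": 0.9}, {"Name": "tech", "Score": 0.1}]},
    {"File": "sport-002.txt", "Classes": [{"Name": "sport", "Score": 0.3}, {"Name": "tech", "Score": 0.7}]},
    {"File": "tech-001.txt", "Classes": [{"Name": "sport", "Score": 0.2}, {"Name": "tech", "Score": 0.8}]},
]
demo_accuracy = per_class_accuracy(demo_results)
print(demo_accuracy)  # {'sport': 0.5, 'tech': 1.0}
```

In the notebook, the same function could be applied to the `results` list loaded from the predictions file.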