{ "cells": [ { "cell_type": "markdown", "source": [ "# Amazon Comprehend - Custom Document Classification Example\n", "\n", "这个 Notebook 演示了如何用 comprehend custom document classification 的功能构建一个文档分类模型. 这里我们用到的数据集是 Kaggle 比赛上的一个 BBC 新闻分类数据. 一共 2225 个文档, 5 个类. 每个文档只属于其中的一类.\n", "\n", "- Data Source: https://www.kaggle.com/datasets/shivamkushwaha/bbc-full-text-document-classification" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "markdown", "source": [ "# 一些准备工作\n", "\n", "- import 需要的包\n", "- 定义 AWS credential, 用来储存数据的本地路径, S3 文件夹, IAM Role 等." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 44, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preview s3dir_root: https://console.aws.amazon.com/s3/buckets/aws-data-lab-sanhe-for-everything-us-east-2?prefix=poc/2022-11-05-document-classification-example/\n" ] } ], "source": [ "# Standard Library\n", "import typing as T\n", "import os\n", "import json\n", "import random\n", "import zipfile\n", "import dataclasses\n", "\n", "\n", "# Third Party Library\n", "from rich import print as rprint\n", "from pathlib_mate import Path\n", "from s3pathlib import S3Path, context\n", "from boto_session_manager import BotoSesManager, AwsServiceEnum\n", "\n", "import polars as pl\n", "\n", "bsm = BotoSesManager(profile_name=\"aws_data_lab_sanhe_us_east_2\")\n", "context.attach_boto_session(bsm.boto_ses)\n", "\n", "ch_client = bsm.get_client(AwsServiceEnum.Comprehend)\n", "\n", "dir_here = Path(os.getcwd())\n", "dir_docs = dir_here / \"docs\"\n", "path_bbc_zip = dir_here / \"bbc.zip\"\n", "dir_bbc = dir_here / \"bbc\"\n", "path_manifest_csv = dir_here / \"manifest.csv\"\n", "\n", "s3dir_root = S3Path.from_s3_uri(\"s3://aws-data-lab-sanhe-for-everything-us-east-2/poc/2022-11-05-document-classification-example/\").to_dir()\n", "s3path_manifest_csv = s3dir_root / \"manifest.csv\"\n", "s3dir_test_data = s3dir_root.joinpath(\"test-data\").to_dir()\n", "s3dir_predict = s3dir_root.joinpath(\"predict\").to_dir()\n", "print(f\"preview s3dir_root: {s3dir_root.console_url}\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# Unzip Dataset\n", "\n", "将 Kaggle 上下载下来的 Zip 解压到本地, 然后再本地浏览一下, 对数据有个基本的概念." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 9, "outputs": [], "source": [ "if not dir_bbc.exists():\n", " with zipfile.ZipFile(path_bbc_zip.abspath) as f:\n", " f.extractall(path=dir_here.abspath)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# Split Train and Test\n", "\n", "将数据按 7:3 的比例分为 training 和 testing.\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 25, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total docs = 2225\n", "train docs = 1557\n", "test docs = 668\n" ] } ], "source": [ "@dataclasses.dataclass\n", "class Document:\n", " path: Path = dataclasses.field()\n", " label: str = dataclasses.field()\n", "\n", "\n", "doc_list: T.List[Document] = list()\n", "for dir in dir_bbc.select_dir(recursive=False):\n", " for path in dir.select_by_ext(\".txt\"):\n", " doc = Document(path=path, label=dir.basename)\n", " doc_list.append(doc)\n", "\n", "n_doc = len(doc_list)\n", "n_train = int(0.7 * n_doc)\n", "n_test = n_doc - n_train\n", "\n", "train_indices: T.Set[int] = set(random.sample(list(range(n_doc)), n_train))\n", "train_doc_list: T.List[Document] = list()\n", "test_doc_list: T.List[Document] = list()\n", "for ind, doc in enumerate(doc_list):\n", " if ind in train_indices:\n", " train_doc_list.append(doc)\n", " else:\n", " test_doc_list.append(doc)\n", "\n", "print(f\"total docs = {n_doc}\")\n", "print(f\"train docs = {n_train}\")\n", "print(f\"test docs = {n_test}\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# 将数据整理成 Comprehend Manifest CSV file 的格式\n", "\n", "Comprehend 只认识这 Manifest 这种格式. 虽然文档中提到了 JSON 格式, 但是这个格式是假设你用 Ground Truth 来 Label 的情况. 我们自己构建还是用 CSV 格式. 简单来说就是要把文档变成一行, 把换行符都变成空格. 详细文档如下.\n", "\n", "- https://docs.aws.amazon.com/comprehend/latest/dg/how-document-classification.html\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 36, "outputs": [], "source": [ "rows: T.List[T.Dict[str, T.Any]] = list()\n", "for doc in train_doc_list:\n", " content = doc.path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n", " line_content = content.replace(\"\\n\", \" \").replace(\"\\r\", \" \")\n", " rows.append(dict(\n", " label=doc.label,\n", " content=line_content,\n", " ))\n", "df = pl.DataFrame(rows)\n", "df.write_csv(path_manifest_csv.abspath, has_header=False)\n", "s3path_manifest_csv.upload_file(path_manifest_csv.abspath, overwrite=True)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# Create Document Classifier\n", "\n", "然后调用 API 训练一个模型, 通常耗时在半小时左右." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 40, "outputs": [], "source": [ "classifier_name = \"MyClassifier\"\n", "classifier_arn = f\"arn:aws:comprehend:{bsm.aws_region}:{bsm.aws_account_id}:document-classifier/{classifier_name}\"\n", "create_response = ch_client.create_document_classifier(\n", " InputDataConfig=dict(\n", " DataFormat=\"COMPREHEND_CSV\",\n", " S3Uri=s3path_manifest_csv.uri,\n", " ),\n", " DataAccessRoleArn=\"arn:aws:iam::669508176277:role/sanhe-comprehend-admin-access\",\n", " DocumentClassifierName=\"MyClassifier\",\n", " LanguageCode=\"en\",\n", " Mode=\"MULTI_CLASS\",\n", ")\n", "rprint(create_response)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 43, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "upload test data to s3 ...\n", " preview s3dir_test_data: https://console.aws.amazon.com/s3/buckets/aws-data-lab-sanhe-for-everything-us-east-2?prefix=poc/2022-11-05-document-classification-example/test-data/\n", " done\n" ] } ], "source": [ "print(\"upload test data to s3 ...\")\n", "print(f\" preview s3dir_test_data: {s3dir_test_data.console_url}\")\n", "for doc in test_doc_list:\n", " s3path = s3dir_test_data.joinpath(f\"{doc.label}-{doc.path.basename}\")\n", " s3path.upload_file(doc.path.abspath)\n", "print(\" done\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# Start Document Classification Job\n", "\n", "然后就可以用训练好的模型做预测了. 其中 API 格式为指定一个 S3 folder 里面每个文件都代表一个文档." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 45, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'JobId': 'c9899916d3ee54ea34a8bd9bb51ba1ca', 'JobArn': 'arn:aws:comprehend:us-east-2:669508176277:document-classification-job/c9899916d3ee54ea34a8bd9bb51ba1ca', 'JobStatus': 'SUBMITTED', 'ResponseMetadata': {'RequestId': '6e0bc2a5-c28a-4dbc-9a76-610627c0f3f2', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '6e0bc2a5-c28a-4dbc-9a76-610627c0f3f2', 'content-type': 'application/x-amz-json-1.1', 'content-length': '182', 'date': 'Sat, 05 Nov 2022 19:10:35 GMT'}, 'RetryAttempts': 0}}\n" ] } ], "source": [ "start_response = ch_client.start_document_classification_job(\n", " InputDataConfig=dict(\n", " S3Uri=s3dir_test_data.uri,\n", " InputFormat=\"ONE_DOC_PER_FILE\",\n", " ),\n", " OutputDataConfig=dict(\n", " S3Uri=s3dir_predict.uri,\n", " ),\n", " DataAccessRoleArn=\"arn:aws:iam::669508176277:role/sanhe-comprehend-admin-access\",\n", " DocumentClassifierArn=classifier_arn,\n", ")\n", "print(start_response)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# Evaluate Prediction Result\n", "\n", "结果出来以后我们就将其跟原来的 Label 做对比, 计算准确度. Comprehend Output 的格式是一个 .tar 文件. 你需要将其解包后就能看到结果是一个 .jsonl 文件, 也就是一行就是一个 JSON. 我们可以看到准确度还是很不错的, 达到了 98.35%." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 56, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy = 0.9835329341317365\n" ] } ], "source": [ "lines = dir_here.joinpath(\"predictions.json\").read_text().strip().split(\"\\n\")\n", "results = [\n", " json.loads(line.strip())\n", " for line in lines\n", "]\n", "n_test = len(results)\n", "n_good = 0\n", "for record in results:\n", " file = record[\"File\"]\n", " true_label = file.split(\"-\")[0]\n", " predict_label = list(sorted(\n", " record[\"Classes\"],\n", " key=lambda x: x[\"Score\"],\n", " reverse=True\n", " ))[0][\"Name\"]\n", " if true_label == predict_label:\n", " n_good += 1\n", "print(f\"accuracy = {n_good / n_test}\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }