Airflow Common Patterns

首先我想讨论一下如何能高效的学习 Airflow. 因为在我的职业生涯中, 我几乎市面上所有的编排系统都使用过, 不过我比较精通的是 Airflow 和 AWS StepFunction. 如何能快速上手并达到一个比较高的水平这件事其实是有规律可循的. 我认为从入门到精通, 你必须要经历三个阶段:

  1. 快速上手, 学会编写你的第一个应用

  2. 进行较为全面的学习, 把所有常见的使用场景, 常用的主打的功能都用一遍

  3. 回头进行总结, 结合生产环境中的实践进一步深入

本文主要介绍的是 #2.

其实对于编排处理系统, 主要考虑的是三个大点, 流程图和依赖关系, 数据处理, 异常处理. 每个大点下又分很多子问题. 如果能数量掌握全部的大点和子问题, 基本上任意复杂的业务逻辑也只是它们的排列组合而已.

  1. 流程图和依赖关系
    • 顺序执行

    • 串行执行

    • 条件分支

    • 分支

    • 合并

    • 异步调用和等待

  2. 数据处理, 各个 Task 之间的数据如何互通
    • 数据传递 (在两个 Task 之间传递数据)

    • 数据共享 (相当于是一个大家都能访问的全局变量)

    • 数据存储 (外部存储)

  3. 异常处理
    • 单个任务的异常和重试

    • 多个并行执行的任务的异常处理

    • 多个串行执行的任务的异常处理

1. 简单的单步任务

直接看例子.

 1# -*- coding: utf-8 -*-
 2
 3"""
 4简单的单步任务.
 5"""
 6
 7from datetime import datetime
 8from airflow.decorators import dag, task
 9
10dag_id = "dag1"
11
12@dag(
13    dag_id=dag_id,
14    start_date=datetime(2021, 1, 1),
15    schedule="@once",
16    catchup=False,
17)
18def dag1():
19    """
20    这个例子中只有一个 Task.
21    """
22    task1_id = "task1"
23
24    @task(
25        task_id=task1_id,
26    )
27    def task1():
28        """
29        该任务随机生成一个 1 - 100 的随机数, 并将这个值和当前的事件一起写入到 S3 中.
30        """
31        # 注意, 任何跟执行任务相关, 需要 import 的包, 都需要在函数内部 import.
32        # 因为 task 函数外部的内容都属于调度器的运行环境, 而 task 函数内则是执行器的运行环境
33        # 通常情况下, 调度器只要有 Airflow 服务器自带的包就够了, 而执行器很可能需要在
34        # virtualenv 中运行, 可能需要任何包.
35        # Ref: https://airflow.apache.org/docs/apache-airflow/stable/best-practices.html#top-level-python-code
36        import json
37        import random
38        from datetime import datetime
39
40        import boto3 # 标准的开源 Airflow 并不预装 boto3, 而 AWS MWAA 预装了 boto3.
41
42        print("Start task1")
43
44        aws_account_id = boto3.client("sts").get_caller_identity()["Account"]
45        aws_region = "us-east-1"
46
47        value = random.randint(1, 100)
48        print(f"Generated value is {value}")
49        data = {
50            "value": value,
51            "datetime": datetime.now().isoformat(),
52        }
53        # 记得确保你的 MWAA 的 IAM Role 里有这个 bucket 的读写权限
54        boto3.client("s3").put_object(
55            Bucket=f"{aws_account_id}-{aws_region}-data",
56            Key=f"projects/mwaa-poc/{dag_id}/{task1_id}.output.json",
57            Body=json.dumps(data),
58            ContentType="application/json",
59        )
60
61        print("End task1")
62        # 在 Python Operator (也就是 @task 装饰器装饰的函数) 中, 返回值将会被视为 Task Output
63        # 这个返回值一定要是一个可序列化的对象, 因为 Airflow 会自动负责将这个对象序列化后以备
64        # 后续的 task 使用 (虽然这个例子中没有).
65        return "Returned by task 1"
66
67    run_task1 = task1() # 你只要调用这个 task 函数就相当于告诉 Airflow 我要运行这个 task 了.
68
69
70run_dag1 = dag1()

2. 简单的两步任务, 串行

直接看例子.

  1# -*- coding: utf-8 -*-
  2
  3"""
  4简单的两步任务, 串行.
  5"""
  6
  7from datetime import datetime
  8from airflow.decorators import dag, task
  9
 10dag_id = "dag2"
 11
 12
 13@dag(
 14    dag_id=dag_id,
 15    start_date=datetime(2021, 1, 1),
 16    schedule="@once",
 17    catchup=False,
 18)
 19def dag2():
 20    """
 21    这个例子中我们有两个 Task, 先运行 Task 1, 再运行 Task 2. 如果 Task 1 不成功, 则不运行 Task 2.
 22    我们的两个 Task 的任务都是生成一个随机数, 然后写入到 S3 中. 不过 Task 1 有 50% 的概率会失败.
 23    你可以看到如果 Task 1 失败了, 则 Task 2 不会被执行.
 24    """
 25    task1_id = "task1"
 26    task2_id = "task2"
 27
 28    @task(
 29        task_id=task1_id,
 30    )
 31    def task1():
 32        import json
 33        import random
 34        from datetime import datetime
 35
 36        import boto3
 37
 38        print("Start task1")
 39
 40        # 有 50% 的概率失败
 41        if random.randint(1, 100) <= 50:
 42            raise Exception("Randomly failed")
 43
 44        aws_account_id = boto3.client("sts").get_caller_identity()["Account"]
 45        aws_region = "us-east-1"
 46
 47        value = random.randint(1, 100)
 48        print(f"Generated value is {value}")
 49        data = {
 50            "value": value,
 51            "datetime": datetime.now().isoformat(),
 52        }
 53        boto3.client("s3").put_object(
 54            Bucket=f"{aws_account_id}-{aws_region}-data",
 55            Key=f"projects/mwaa-poc/{dag_id}/{task1_id}.output.json",
 56            Body=json.dumps(data),
 57            ContentType="application/json",
 58        )
 59
 60        print("End task1")
 61        return "Returned by task 1"
 62
 63    @task(
 64        task_id=task2_id,
 65    )
 66    def task2():
 67        import json
 68        import random
 69        from datetime import datetime
 70
 71        import boto3
 72
 73        print("Start task2")
 74
 75        aws_account_id = boto3.client("sts").get_caller_identity()["Account"]
 76        aws_region = "us-east-1"
 77
 78        value = random.randint(1, 100)
 79        print(f"Generated value is {value}")
 80        data = {
 81            "value": value,
 82            "datetime": datetime.now().isoformat(),
 83        }
 84        boto3.client("s3").put_object(
 85            Bucket=f"{aws_account_id}-{aws_region}-data",
 86            Key=f"projects/mwaa-poc/{dag_id}/{task2_id}.output.json",
 87            Body=json.dumps(data),
 88            ContentType="application/json",
 89        )
 90
 91        print("End task2")
 92        return "Returned by task 2"
 93
 94    # 这里调用了两个 task 函数, 就相当于告诉 Airflow 我要运行他们两. 默认情况下 Airflow
 95    # 认为它们是并行, 没有依赖关系, 但是如果你用 >>, << 这样的符号连接他们, 就表示他们是
 96    # 有依赖关系的.
 97    # Ref: https://airflow.apache.org/docs/apache-airflow/stable/tutorial/fundamentals.html#setting-up-dependencies
 98    run_task1 = task1()
 99    run_task2 = task2()
100
101    run_task1 >> run_task2
102
103
104run_dag2 = dag2()

3. 简单的两步任务, 并行

直接看例子.

  1# -*- coding: utf-8 -*-
  2
  3"""
  4简单的两步任务, 并行.
  5"""
  6
  7from datetime import datetime
  8from airflow.decorators import dag, task
  9
 10dag_id = "dag3"
 11
 12
 13@dag(
 14    dag_id=dag_id,
 15    start_date=datetime(2021, 1, 1),
 16    schedule="@once",
 17    catchup=False,
 18)
 19def dag3():
 20    """
 21    这个例子中我们有两个 Task, 两个 Task 并行.
 22    我们的两个 Task 的任务都是生成一个随机数, 然后写入到 S3 中. 不过 Task 1 有 50% 的概率会失败.
 23    对于并行的任务, 如果其中一个失败了, 则整个 DAG 会失败, 但是并不会阻止其他任务失败.
 24    根据实验结果, 可以看出不管 Task 1 失败与否, Task 2 都成功了.
 25
 26    而 Airflow 允许你手动的对具体的 Task 进行重试, 如果你在 Airflow UI 里可以看到 Task 1
 27    红了 (失败了), 你点进去在 Task Action 里 Run 即可对 Task 1 重运行. 不过默认是不允许
 28    已经失败的任务进行重试的. 除非你手动选择了 Ignore All Deps 选项 (它的意思是忽略所有依赖
 29    条件, 这里的依赖就包括了不允许冲运行已经失败的任务).
 30    """
 31    task1_id = "task1"
 32    task2_id = "task2"
 33
 34    @task(
 35        task_id=task1_id,
 36    )
 37    def task1():
 38        import json
 39        import random
 40        from datetime import datetime
 41
 42        import boto3
 43
 44        print("Start task1")
 45
 46        # 有 50% 的概率失败
 47        if random.randint(1, 100) <= 50:
 48            raise Exception("Randomly failed")
 49
 50        aws_account_id = boto3.client("sts").get_caller_identity()["Account"]
 51        aws_region = "us-east-1"
 52
 53        value = random.randint(1, 100)
 54        print(f"Generated value is {value}")
 55        data = {
 56            "value": value,
 57            "datetime": datetime.now().isoformat(),
 58        }
 59        boto3.client("s3").put_object(
 60            Bucket=f"{aws_account_id}-{aws_region}-data",
 61            Key=f"projects/mwaa-poc/{dag_id}/{task1_id}.output.json",
 62            Body=json.dumps(data),
 63            ContentType="application/json",
 64        )
 65
 66        print("End task1")
 67        return "Returned by task 1"
 68
 69    @task(
 70        task_id=task2_id,
 71    )
 72    def task2():
 73        import json
 74        import random
 75        from datetime import datetime
 76
 77        import boto3
 78
 79        print("Start task2")
 80
 81        aws_account_id = boto3.client("sts").get_caller_identity()["Account"]
 82        aws_region = "us-east-1"
 83
 84        value = random.randint(1, 100)
 85        print(f"Generated value is {value}")
 86        data = {
 87            "value": value,
 88            "datetime": datetime.now().isoformat(),
 89        }
 90        boto3.client("s3").put_object(
 91            Bucket=f"{aws_account_id}-{aws_region}-data",
 92            Key=f"projects/mwaa-poc/{dag_id}/{task2_id}.output.json",
 93            Body=json.dumps(data),
 94            ContentType="application/json",
 95        )
 96
 97        print("End task2")
 98        return "Returned by task 2"
 99
100    # 这里调用了两个 task 函数, 就相当于告诉 Airflow 我要运行他们两. 默认情况下 Airflow
101    # 认为它们是并行, 没有依赖关系.
102    # Ref: https://airflow.apache.org/docs/apache-airflow/stable/tutorial/fundamentals.html#setting-up-dependencies
103    run_task1 = task1()
104    run_task2 = task2()
105
106
107run_dag3 = dag3()

4. 使用第三方 Python 包

直接看例子.

 1# -*- coding: utf-8 -*-
 2
 3"""
 4在 Task 中使用第三方包, 也就是使用 Virtualenv 来运行 Task.
 5"""
 6
 7from datetime import datetime
 8from airflow.decorators import dag, task
 9
10dag_id = "dag4"
11
12
13@dag(
14    dag_id=dag_id,
15    start_date=datetime(2021, 1, 1),
16    schedule="@once",
17    catchup=False,
18)
19def dag4():
20    """ """
21    task1_id = "task1"
22
23    # @task.virtualenv 装饰器表示这个 task 需要在一个 Virtualenv 中运行.
24    # 它的底层其实是 PythonVirtualenvOperator
25    # Ref: https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/operators/python/index.html#airflow.operators.python.PythonVirtualenvOperator
26    @task.virtualenv(
27        task_id=task1_id,
28        # 如果不指定这个参数, 则默认使用 Airflow 所用的 Python 版本, 这里我们看了 MWAA 上
29        # Airflow 2.5.1 对应的 Python 版本是 3.10
30        # 如果你需要使用一个跟 Airflow 不同的 Python 解释器, 那么推荐你不要使用 task.virtualenv
31        # 而是使用 @task.external_python. 你需要在 Airflow 服务器上安装你需要的版本的 Python,
32        # 并在参数中指定 Python interpreter 的路径. 不过在 MWAA 上你做不到这一点, 因为
33        # 服务器是由 AWS 所管理的. 这种情况下我推荐使用 docker package, 用容器来运行
34        # Ref: https://airflow.apache.org/docs/apache-airflow-providers-docker/stable/index.html
35        python_version="3.10",
36        # 这里是所有依赖的具体版本, 我建议使用固定版本, 不要使用范围, 这样可以保证每次
37        # 运行的环境 100% 一样. 至于如何确定固定版本, 我推荐使用 poetry 来管理并解析依赖
38        # 然后输出一个完全固定的版本列表, 并包括 hashes, 以防止注入攻击.
39        requirements=[
40            "s3pathlib>=2.0.1,<3.0.0",
41        ],
42        # 表示是否把 system python 中的包也自动添加进来. 这个 system python 也就是
43        # Airflow 所在的 Python, Airflow 可包含了一堆杂七杂八哦的包, 说不准哪个就跟你的
44        # 项目依赖有冲突, 99.99% 的情况下都选择 False.
45        system_site_packages=False,
46    )
47    def task1():
48        import json
49        import random
50        from datetime import datetime
51
52        import boto3
53        from s3pathlib import S3Path, context
54
55        print("Start task1")
56
57        aws_region = "us-east-1"
58        boto_ses = boto3.session.Session(region_name=aws_region)
59        aws_account_id = boto_ses.client("sts").get_caller_identity()["Account"]
60        context.attach_boto_session(boto_ses)
61
62        value = random.randint(1, 100)
63        print(f"Generated value is {value}")
64        data = {
65            "value": value,
66            "datetime": datetime.now().isoformat(),
67        }
68        s3path = S3Path(
69            f"s3://{aws_account_id}-{aws_region}-data/projects/mwaa-poc/{dag_id}/{task1_id}.output.json"
70        )
71        s3path.write_text(
72            json.dumps(data),
73            content_type="application/json",
74        )
75
76        print("End task1")
77        return "Returned by task 1"
78
79    run_task1 = task1()
80
81
82run_dag4 = dag4()

5. 在相邻的两个 Task 之间传递数据

6. 在任意的两个 Task 之间传递小数据

在编排任务中, 在任意两个 Task 之间, 包括不相邻的两个 Task 之间传递数据是很常见的需求. 从 Airflow 1.X 起, 就自带 XComs (Cross Communication) 这一功能, 能在 Tasks 之间传递数据. 它的原理其实是在 Scheduler 上维护一个 Key Value Store, 其中 Key 是 dag_id + task_id 合起来的一个 compound key. 你可以把 value 存在里面, 自然也可以在任何其他的 task 中引用这个 value. 而从 Airflow 2.X 起, 引入了 TaskFlow 这一更加人类友好的 API. 在 TaskFlow API 下, 所有的 PythonOperator Task 的返回值都会默认被包装为一个 XComs, 而你可以直接像写 Python 函数一样在 Tasks 之间传递参数, 而无需显式在其他 Task 引用之前的 Task 的返回值.

但是注意, 能被 XComs 传递的数据必须要是可序列化的对象, 例如 Str, Int, 或是 JSON dict. 而且大小不能超过 48KB. 但这不是什么问题. 对于复杂数据结构, 你只要自己定义一套轻量的 JSON 序列化接口来返回 Task 的输出即可, 你甚至可以用 pickle 或是 JSONPickle 将其 dump 成二进制数据然后 base64 编码. 而对于体积很大的数据, 你可以将数据写入到 AWS S3, 然后返回一个 S3 uri, 传递给后续的 task, 然后后续的 task 再从 S3 读取数据即可.

Reference:

7. 在任意的两个 Task 之间传递大数据

有时候一个 Task 返回的数据量非常大, 由于 XComs 又 48KB 的限制, 这时就要另外想办法了. XComs 被设计为用来传递小数据的, 它不允许你用它来储存任意大的数据, 占用服务器资源.

一个比较直观的解决方案是利用全局可用的 context 对象. 它是一个字典的数据结构, 在整个 DAG 执行的生命周期内都存在. 记录了一些全局变量之类的信息. 其中就有一个 run_id 的字段, 它的值是一个包含时间戳的唯一的值, 看起来像这样 manual__2023-01-01T01:23:45.123456+00:00scheduled__2021-01-01T00:00:00+00:00. 精确到微秒. 其中 manual 代表你手动运行的, 而 scheduled 代表按照 scheduler 的调度规则执行的. 你可以用它和 DAG ID 合起来作为一个唯一的 Key, 然后用任何 Key Value Store 的 backend 来储存这个数据. 例如 AWS S3 或 DynamoDB 都可以. 如果你的需要访问这个数据的并发性高 (例如你用到了 Map 并行, 所有并行 Task 都需要读写同一个数据) 且数据量不大 (48KB - 400KB) 之间, 那么用 DynamoDB 就比较合适, 能确保读写的原子性. 而如果你仅仅是进行数据传递但数据量很大, 那么用 S3 就比较合适.

我们稍微的扩展一下, 其实我们可以不限定只使用一个 Key, 而是基于 DAG ID 和 run_id 可以创造出多个 Key, 然后将 Key 作为参数用 XComs 返回即可. 基于这种策略你可以几乎做到任何事.

Reference:

8. Poll for Job Status 模式

Poll for Job 模式常用于你有一个异步执行, 耗时较长的 Task 的情况. 举例来说, 你要用 AWS Glue 来运行一个耗时在 1 - 5 分钟的 ETL Job. 如果 Glue Job 成功, 则继续后面的步骤. 如果 Glue Job 失败, 则停止整个 DAG. 如果 Glue Job 超时, 则视为失败, 也停止 DAG. 这个模式就叫做 Poll for Job.

这个模式的一个通用解决方案是, 在异步执行这个 Task 后 (我们继续拿 AWS Glue Job 举例) 隔一段时间就去查询一下运行状态, 如果是 in progress 就等下一次, 如果是 succeeded 就继续后面的步骤, 如果是 failed 或 timeout 就停止 DAG.

而根据直觉, 这个 “隔一段时间” 一般我们会用 import time, time.sleep(60) 来实现. 等于这个进程被挂在那里 60 秒, 占用了服务器资源. 这也叫 同步轮询. 想象一下, 很可能你 check status 只需要 1 秒, 但 60 秒都在占用服务器资源, 如果你有 60 个这样的程序, 相当于你有 60 个 Python 程序一直在吃内存.

一个成熟的调度系统一般会把这个等待的事情交给一个调度系统来做, 例如一个每 1 毫秒心跳一次的 event loop. 当它等待 60 秒后, 就异步的来执行这个 check status 的操作, 所以这也叫 异步轮询. 这样 60 个这样的程序本质上相当于每个只占在运行的那 1 秒内占用了系统资源, 60 个程序占用的资源和 1 个旧版程序一样多, 大大提高了资源利用率.

我们来看一下官方文档中的一个例子, 理解一下在 Airflow 中应该如何实现这个模式. 为了帮助理解, 我在官方的文档的基础上增加了很多保姆级注释.

from airflow.decorators import dag, task
from airflow.sensors.base import PokeReturnValue

# @task.sensor 的全部参数请参考下面两个文档, 其中 BaseSensorOperator 是基类,
# ExternalTaskSensor 是子类, 全部参数是两者之和
# - BaseSensorOperator: https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/base/index.html#airflow.sensors.base.BaseSensorOperator
# - ExternalTaskSensor: https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/external_task/index.html#airflow.sensors.external_task.ExternalTaskSensor
@task.sensor(
    # 每多少秒检查一次状态
    poke_interval=60,
    # 以供多少秒后超时
    timeout=300,
    # Airflow 同时支持 同步轮询 和 异步轮询 两种模式
    # poke 就是同步模式, 等待期间也占用一个 worker 字眼
    # reschedule 就是异步模式,
    # 简单来说, 你查询的频率越高, 那么异步执行反复释放和重新 import 到内存的开销就越不值当
    # 就应该用 poke 模式索性一直占用一个 worker
    # 而如果你查询的频率不高, 那么就应该在等待期间释放资源, 等下一次再 load 也没关系
    # 就应该用 reschedule 模式
    # mode="poke",
    mode="reschedule",
)
def wait_for_upstream() -> PokeReturnValue:
    """
    用 @task.sensor 包装起来的函数最好返回一个 PokeReturnValue 对象, 它只有两个参数:

    - is_done: bool, 表示是否已经完成, 这里的完成指的是是否可以停止轮询, 无论是成功还是
        失败都可以视为 "完成". 逻辑上你查询到的外部任务的状态如果你认为没有必要再查询了,
        例如成功, 失败, 或者是失败进行中, 失败后回滚中, 你都能预料到最终结局了, 这时就
        该返回 is_done = True.
    - xcom_value: 这是 Airflow 在两个 task 之间传递 data 的机制, 必须要是个可序列化对象,
        而且不能超过 48KB. 一般你用来将外部任务的状态信息传递给下一个 task, 例如成功, 失败,
        以及任何额外的信息, 例如失败原因, debug 信息等.

    这里有一个点. 很多人会想说如果失败了, 那么就直接在这个 Sensor task 里 raise 一个异常.
    如果成功了, 就直接在这个 task 里继续运行下一个 task 的逻辑, 这样就可以少写一个 task,
    降低系统复杂度了. 其实不然. 这样做等于是将调度逻辑和业务逻辑放在了一起, 失去了使用调度
    系统的意义. 并且你的代码违背了一段代码只做一件事的原则, 如果你的业务代码不小心有 bug,
    那么就会拖累这个调度逻辑, 使得调度逻辑也失败了. 这样做弊大于利.
    """
    # your get status logic here
    # the status is a string: "doing", "succeeded", "failed"
    status: str = get_status(...)
    if status == "doing":
        return PokeReturnValue(is_done=False)
    elif status in ["succeeded", "failed"]:
        return PokeReturnValue(is_done=True, xcom_value=status)
    else:
        raise NotImplementedError(f"unknown status: {status!r}"))

好了我们了解了原理之后, 就来看一个非常具体的例子. 下面这个脚本模拟了一个耗时 20 秒的外部任务. 请仔细阅读里面的注释, 里面介绍了这个任务的逻辑.

 1# -*- coding: utf-8 -*-
 2
 3"""
 4用来模拟外部任务. 在 invoke dag 之前请运行这个脚本, 然后再在 5 秒内 invoke dag.
 5"""
 6
 7import time
 8import random
 9from boto_session_manager import BotoSesManager
10from s3pathlib import S3Path, context
11
12bsm = BotoSesManager(profile_name="awshsh_app_dev_us_east_1")
13context.attach_boto_session(bsm.boto_ses)
14
15s3path = S3Path(
16    f"s3://{bsm.aws_account_id}-{bsm.aws_region}-data/"
17    f"projects/mwaa-poc/dag8/task0.status.txt",
18)
19
20
21def run():
22    """
23    每 2 秒循环一次, 共循环 10 次, 每次循环有 5% 的概率失败. 如果 10 次循环都成功, 则任务成功.
24    经过简单的概率计算, 10 次循环都成功的概率为 60%.
25
26    在开始的时候就往 S3 中写入 doing, 任务成功后写入 succeeded, 任务失败后写入 failed.
27    """
28    s3path.write_text("doing")
29    for i in range(1, 10 + 1):
30        print(f"{i} / 10 th iteration")
31        time.sleep(2)
32        if random.randint(1, 100) <= 5:
33            s3path.write_text("failed")
34            print("failed")
35            return
36    print("succeeded")
37    s3path.write_text("succeeded")
38
39
40s3path.delete()
41run()

然后我们来看一下这个 DAG 定义. 同样的, 请自诩阅读里面的注释.

 1# -*- coding: utf-8 -*-
 2
 3"""
 4如何实现 Long Polling 的模式, 也就是启动一个 Task 之后, 等待这个 Task 成功, 失败, 或超时.
 5"""
 6
 7from datetime import datetime
 8from airflow.decorators import dag, task
 9from airflow.sensors.base import PokeReturnValue
10
11dag_id = "dag8"
12
13
14@dag(
15    dag_id=dag_id,
16    start_date=datetime(2021, 1, 1),
17    schedule="@once",
18    catchup=False,
19)
20def dag8():
21    """
22    在这个例子中我们会在本地电脑上运行一个耗时 20 秒的脚本, 用于模拟 ExternalTask 的异步执行.
23    """
24    task1_id = "task1"
25    task2_id = "task2"
26
27    @task.sensor(
28        task_id=task1_id,
29        poke_interval=2,
30        timeout=30,
31        mode="poke",
32    )
33    def task1_check_status():
34        """
35        这个任务会隔一段时间检查一次 Task0 (也就是外部任务) 的状态, 超过 30 秒还没有成功或失败
36        就算超时. 由于我们的检查频率较高, 所以用 poke 模式.
37        """
38        import boto3
39        from datetime import datetime
40
41        print("Start task1")
42        print(f"current time is {datetime.utcnow()}")
43
44        # 从 S3 中读取 task0 的状态
45        aws_account_id = boto3.client("sts").get_caller_identity()["Account"]
46        aws_region = "us-east-1"
47
48        status = (
49            boto3.client("s3")
50            .get_object(
51                Bucket=f"{aws_account_id}-{aws_region}-data",
52                Key=f"projects/mwaa-poc/{dag_id}/task0.status.txt",
53            )["Body"]
54            .read()
55            .decode("utf-8")
56        )
57        print(f"get status: {status!r}")
58
59        # 如果状态是 succeeded 或 failed, 则结束等待, 并将状态作为返回值返回
60        if status in ["succeeded", "failed"]:
61            print("End task1 check status")
62            return PokeReturnValue(is_done=True, xcom_value=status)
63        else:
64            return PokeReturnValue(is_done=False)
65
66    @task(
67        task_id=task2_id,
68    )
69    def task2_decide_next_step(status):
70        """
71        对检测到的 task 0 的状态信息进行处理.
72        """
73        print("Start task2")
74        print(f"received status is {status!r}")
75        if status == "succeeded":
76            print("external task succeeded")
77            print("End task2")
78            return "Returned by task 2"
79        elif status == "failed":
80            print("external task failed")
81            raise ValueError("external task failed")
82        else:  # pragma: no cover
83            raise NotImplementedError
84
85    # 把 task1_check_status 的返回值传递给 task2_decide_next_step
86    task2_decide_next_step(task1_check_status())
87
88
89run_dag8 = dag8()

好了现在我们对如何用 Airflow 实现 Poll Job 模式有一个比较清晰的认识了. 我们不妨再展开一点点,

如果这个耗时较长的业务逻辑是写在 Task 里的, 注意这里说的不是你调用一个外面的 API 让真正的运算资源远程执行, 而是说你的业务逻辑例如数据处理就卸载了 Task 函数里面. 这种情况你没有必要用 Sensor. 你直接就在这个 Task 里用 try ... except ... 判断是否成功, 如果出现了 try ... except ... 每预料到的错误, 那么就让他 fail 即可.

而如果这个耗时较长的业务只是一个外部的 API 调用, 例如你要跑一个 AWS Glue Job. 那么你在这个 Task 里就写异步调用的逻辑即可, 然后将 job id 传递给后面的 Sensor task 来轮询状态即可. 例如:

@task(...)
def run_aws_glue_job() -> str:
    res = boto3.client("glue").glue.start_job_run(...)
    return res["JobRunId"]

9. Fan out and Fan in 模式

Fan out 指的是一个任务之后, 并行的执行多个任务.

而 Fan in 指的是只有多个任务都完成之后, 才能继续执行下一个任务. 这里有一些变化, 例如你想要 3 个任务里的至少 2 个完成才执行下一个. 又或者你想要 1 号任务必须完成, 而 2 和 3 之中至少有一个完成. 这种复杂的情况该怎么做呢? 简单来说, BaseOperator 有一个参数 trigger_rule 可以是 { all_success | all_failed | all_done | all_skipped | one_success | one_done | one_failed | none_failed | none_failed_min_one_success | none_skipped | always } 中的一个, 可以决定某个人物是否执行的条件是前置的多个任务全部成功, 全部失败, 有一个失败, 等等. 可以应对大部分的业务情况. 而对于超级复杂的情况, 你可以简单的把所有的前置 task 都用 try except 包裹起来自己做异常处理, 然后返回一个 status. 然后后续任务使用 all_done, 然后把所有前置 task 的返回值作为参数传入, 然后自己用 if else 判断分析即可实现任何复杂逻辑.

请看下面这个 DAG 的例子.

 1# -*- coding: utf-8 -*-
 2
 3"""
 4Fan out and Fan in 模式. 以及在 Fan in 模式下的复杂异常处理.
 5"""
 6
 7from datetime import datetime
 8from airflow.decorators import dag, task
 9
10dag_id = "dag9"
11
12
13@dag(
14    dag_id=dag_id,
15    start_date=datetime(2021, 1, 1),
16    schedule="@once",
17    catchup=False,
18)
19def dag9():
20    """
21    在这个例子中 task1 fan out 成了 task2 和 task3, task2 和 task4 fan in 成了 task5.
22    其中 task2, 4 都有 30% 的几率出错 (大约 50% 的几率两个都成功). 为了能让 task5 能对
23    task2, task4 可能出现的各种复杂的成功失败情况进行处理, 我们让 task2 和 task4 在成功的
24    时候返回一个值. 而失败的时候抛出异常. 而 task5 则将 task2 和 task4 的返回值作为参数.
25    那么一旦收到返回值则说明成功了, 而收到 None 则意味着失败了. 这样我们就可以用 if, else
26    写出任意复杂的逻辑了. 这里注意 task5 的 trigger rule 是 all_done. 也就是说 task2, 4
27    都成功或失败, 总之是结束了之后再运行.
28    """
29    task1_id = "task1"
30    task2_id = "task2"
31    task3_id = "task3"
32    task4_id = "task4"
33    task5_id = "task5"
34    task6_id = "task6"
35
36    @task(task_id=task1_id)
37    def task1():
38        print("Start task1")
39        print("End task1")
40        return "Returned by task 1"
41
42    @task(task_id=task2_id)
43    def task2():
44        import random
45        print("Start task2")
46        if random.randint(1, 100) <= 30:
47            raise Exception("task2 failed")
48        print("End task2")
49        return "Returned by task 2"
50
51    @task(task_id=task3_id)
52    def task3():
53        import time
54        print("Start task3")
55        time.sleep(5)
56        print("End task3")
57        return "Returned by task 3"
58
59    @task(task_id=task4_id)
60    def task4():
61        import time
62        import random
63        print("Start task4")
64        time.sleep(5)
65        if random.randint(1, 100) <= 30:
66            raise Exception("task2 failed")
67        print("End task4")
68        return "Returned by task 4"
69
70    @task(task_id=task5_id, trigger_rule="all_done") # 这里的 trigger rule 是 all_done
71    def task5(task2_return_value: str, task4_return_value: str):
72        print("Start task5")
73        print(
74            f"received task2 return value: {task2_return_value!r}, task4 return value: {task4_return_value!r}"
75        )
76        print("End task5")
77
78    @task(task_id=task6_id)
79    def task6():
80        print("Start task6")
81        print("End task6")
82        return "Returned by task 6"
83
84    run_task1 = task1()
85    run_task2 = task2()
86    run_task3 = task3()
87    run_task4 = task4()
88    # task 5 把 2 和 4 的两个前置的返回值作为参数传入, 如果失败了则传入的是 None
89    run_task5 = task5(run_task2, run_task4)
90    run_task6 = task6()
91
92    # Fan out
93    run_task1 >> [run_task2, run_task3]
94    run_task2 >> run_task5
95    run_task3 >> run_task4 >> run_task5
96    run_task5 >> run_task6
97
98run_dag9 = dag9()