Python Library - mock / pytest-mock

Keywords: Python, Test, Mock, Pytest

Mock 的作用

经常我们会有一个函数被很多其他函数调用, 而这个函数依赖于一些真实的资源. 比如从服务器上获取一个文件. 这里我们在测试的时候可能并不关心获取这个文件的过程, 而是关心对这个文件进行后续处理的逻辑. 这时候我们如何测试呢?

  1. 在函数设计的时候可能后续处理的逻辑的输入只有数据, 并不涉及获取文件这个动作. 我们是应该这么设计, 但是并不解决我们的问题. 因为最终的 API 的输入很可能只是包括这个文件的位置, 并没有数据.

  2. 我们可以给这个获取文件的函数留一个可选参数 result, 如果 result 存在, 则直接返回而不执行获取文件的动作. 这样做对代码有侵入, 会留下很多 if else 垃圾代码. 并且这只适合返回固定值的情况, 无法做到对输入的参数进行一些静态计算后返回.

这就是 Mock 起作用的地方, mock 可以将你的某个函数替换成你自定义的函数, 或是自定义一个固定返回值. 而且你可以指定这个 Mock 起作用的作用域, 非常灵活. 再者甚至能对这个函数的调用次数进行统计, 追踪每次调用的情况.

几个例子

假设我们有一个 get_status 的函数还没有被实现, 而最终的 api 是 get_status_api. 我们要对 get_status_api 进行测试.

# -*- coding: utf-8 -*-

def get_status(name: str) -> str:
    print("real get_status() function hit")
    raise NotImplementedError
    return f"status of {name}: GOOD"


def get_status_api(name: str) -> str:
    return get_status(name)


def run_sql(sql: str) -> int:
    print("real run_sql() function hit")
    raise NotImplementedError
    return connect.execute(sql)


class Order:
    def __init__(self, order_id: str):
        self.order_id = order_id

    @property
    def n_items(self) -> int:
        sql = f"select COUNT(*) from items where items.order_id == {self.order_id}"
        return run_sql(sql)


if __name__ == "__main__":
    print(get_status_api(name="alice"))
# -*- coding: utf-8 -*-

"""
Use mock without pytest
"""

import pytest
import my_module
from my_module import get_status_api
from unittest.mock import patch


# the fake function to replace the not implemented get_status
def fake_get_status(name: str) -> str:
    return f"status of {name}: BAD"


# temporarily replace your function with a fake function
with patch.object(my_module, "get_status", wraps=fake_get_status) as mock_get_status:
    print(get_status_api(name="Alice"))

with pytest.raises(NotImplementedError):
    get_status_api(name="Alice")
# -*- coding: utf-8 -*-

"""
Patch a function with a fixed return value
"""

import os
import pytest
from my_module import get_status, get_status_api


def test_1(mocker):
    # this would work
    mocker.patch("my_module.get_status", return_value="status of Alice: GOOD")
    assert get_status_api(name="Alice") == "status of Alice: GOOD"

    # don't use that directly
    # print(get_status(name="BOB"))


def test_2(mocker):
    # not work, the patch context is not available
    with pytest.raises(NotImplementedError):
        get_status_api(name="Alice")


if __name__ == "__main__":
    basename = os.path.basename(__file__)
    pytest.main([basename, "-s", "--tb=native"])
# -*- coding: utf-8 -*-

"""
Replace a specific function with your own function.
"""

import os
import pytest
from my_module import get_status_api


# the fake function to replace the not implemented get_status
def fake_get_status(name: str) -> str:
    return f"status of {name}: BAD"


def test_1(mocker):
    mocker.patch("my_module.get_status", wraps=fake_get_status)
    assert get_status_api(name="Alice") == "status of Alice: BAD"


def test_2(mocker):
    with pytest.raises(NotImplementedError):
        get_status_api(name="Alice")


if __name__ == "__main__":
    basename = os.path.basename(__file__)
    pytest.main([basename, "-s", "--tb=native"])
# -*- coding: utf-8 -*-

"""
You can define a function with mocker pytest fixture, define your mock
logic in that function, and call that function when needed in your test.
"""

import os
import pytest
from my_module import get_status_api


# the fake function to replace the not implemented get_status
def fake_get_status(name: str) -> str:
    return f"status of {name}: BAD"


# implement your mock logic
def do_patch(mocker):
    mocker.patch("my_module.get_status", wraps=fake_get_status)


def test_1(mocker):
    do_patch(mocker)  # use for test 1
    assert get_status_api(name="Alice") == "status of Alice: BAD"


def test_2(mocker):
    do_patch(mocker)  # use for test 2
    assert get_status_api(name="Alice") == "status of Alice: BAD"


def test_3(mocker):
    # not use for test 3
    with pytest.raises(NotImplementedError):
        get_status_api(name="Alice")


if __name__ == "__main__":
    basename = os.path.basename(__file__)
    pytest.main([basename, "-s", "--tb=native"])

Patch a Variable

这个例子我们介绍了如何 mock 一个 模块内的 变量的值.

使用 Mock 有一个重点: “假设你有一个 变量/函数 你需要 mock 它的行为, 千万不要在你定义它的地方 Mock, 而是要在你 使用 它的地方 mock”

首先我们有这样一个模块:

my_package
|--- app.py
|--- constant.py
# -*- coding: utf-8 -*-
# content of ``my_package/constant.py``

key1 = "value1"
key2 = "value2"
key3 = "value3"
# -*- coding: utf-8 -*-
# content of ``my_package/app.py``

from . import constant, helpers
from .config import Config


def print_constant():
    print(f"constant.key1 = {constant.key1}")
    print(f"constant.key2 = {constant.key2}")
    print(f"constant.key3 = {constant.key3}")


def print_now():
    print(f"now is {helpers.utc_now()}")


def print_config():
    config = Config(env="prod")
    print(f"config.project_name = {config.project_name}")


def print_db_table_name():
    config = Config(env="prod")
    print(f"config.make_db_table_name = {config.make_db_table_name(env='prod')}")

可以看出 my_package.constant.key1 是在 constant.py 模块中被定义的, 而是在 my_package.app.print_constant 函数中被使用的. 我们现在要来测试 print_constant 函数, 不过我们想要替换掉 key1 的值.

# -*- coding: utf-8 -*-

import os
import pytest
from unittest.mock import patch

from my_package.app import print_constant


def test():
    print("do test with mock")
    with patch("my_package.constant.key1", "value111"):
        print_constant()

    print("without mock")
    print_constant()


if __name__ == "__main__":
    basename = os.path.basename(__file__)
    pytest.main([basename, "-s", "--tb=native"])

最后的结果是这样的:

do test with mock
constant.key1 = value111
constant.key2 = value2
constant.key3 = value3
without mock
constant.key1 = value1
constant.key2 = value2
constant.key3 = value3

Patch a Function

在这个例子里我们重点介绍了如何 mock 一个 模块内的 函数的返回值.

使用 Mock 有一个重点: “假设你有一个 变量/函数 你需要 mock 它的行为, 千万不要在你定义它的地方 Mock, 而是要在你 使用 它的地方 mock”

首先我们有这样一个模块:

my_package
|--- app.py
|--- helpers.py
# -*- coding: utf-8 -*-
# content of ``my_package/constant.py``

from datetime import datetime, timezone


def utc_now() -> datetime:
    return datetime.utcnow().replace(tzinfo=timezone.utc)
# -*- coding: utf-8 -*-
# content of ``my_package/app.py``

from . import constant, helpers
from .config import Config


def print_constant():
    print(f"constant.key1 = {constant.key1}")
    print(f"constant.key2 = {constant.key2}")
    print(f"constant.key3 = {constant.key3}")


def print_now():
    print(f"now is {helpers.utc_now()}")


def print_config():
    config = Config(env="prod")
    print(f"config.project_name = {config.project_name}")


def print_db_table_name():
    config = Config(env="prod")
    print(f"config.make_db_table_name = {config.make_db_table_name(env='prod')}")

可以看出 my_package.helpers.utc_now 是在 helpers.py 模块中被定义的, 而是在 my_package.app.print_now 函数中被使用的. 我们现在要来测试 print_now 函数, 不过我们想要替换掉 utc_now 的值.

# -*- coding: utf-8 -*-

import os
import pytest
from datetime import datetime, timezone
from unittest.mock import patch

from my_package.app import print_now


def test():
    print("do test with mock")
    with patch(
        "my_package.helpers.utc_now",
        return_value=datetime(2000, 1, 1, tzinfo=timezone.utc),
    ):
        print_now()

    print("without mock")
    print_now()


if __name__ == "__main__":
    basename = os.path.basename(__file__)
    pytest.main([basename, "-s", "--tb=native"])

最后的结果是这样的:

do test with mock
now is 2000-01-01 00:00:00+00:00
without mock
now is 2022-10-04 16:57:16.373059+00:00

Patch a Class Property and Method

在这个例子我们重点介绍了如何 mock 一个 模块内的 类 的 property 函数以及普通的 method 函数的.

首先我们有这样一个模块:

my_package
|--- app.py
|--- config.py
# -*- coding: utf-8 -*-

import dataclasses


@dataclasses.dataclass
class Config:
    env: str = dataclasses.field()

    @property
    def project_name(self) -> str:
        return f"my-project-{self.env}"

    def make_db_table_name(self, env: str) -> str:
        return f"my_table_{env}"
# -*- coding: utf-8 -*-
# content of ``my_package/app.py``

from . import constant, helpers
from .config import Config


def print_constant():
    print(f"constant.key1 = {constant.key1}")
    print(f"constant.key2 = {constant.key2}")
    print(f"constant.key3 = {constant.key3}")


def print_now():
    print(f"now is {helpers.utc_now()}")


def print_config():
    config = Config(env="prod")
    print(f"config.project_name = {config.project_name}")


def print_db_table_name():
    config = Config(env="prod")
    print(f"config.make_db_table_name = {config.make_db_table_name(env='prod')}")

mock 一个 property

# -*- coding: utf-8 -*-

import os
import pytest
from unittest.mock import patch, PropertyMock

from my_package.app import print_config


def test():
    print("do test with mock")
    with patch(
        "my_package.config.Config.project_name",
        new_callable=PropertyMock,
        return_value="my-project-dev",
    ):
        print_config()

    print("without mock")
    print_config()


if __name__ == "__main__":
    basename = os.path.basename(__file__)
    pytest.main([basename, "-s", "--tb=native"])

最后的结果是这样的:

do test with mock
config.project_name = my-project-dev
without mock
config.project_name = my-project-prod

mock 一个 method

# -*- coding: utf-8 -*-

import os
import pytest
from unittest.mock import patch, PropertyMock

from my_package.app import print_db_table_name


def mock_make_db_table_name(env: str) -> str:
    return f"this_is_mock_table_{env}"


def test():
    print("do test with mock")
    with patch(
        "my_package.config.Config.make_db_table_name",
        wraps=mock_make_db_table_name,
    ):
        print_db_table_name()

    print("without mock")
    print_db_table_name()


if __name__ == "__main__":
    basename = os.path.basename(__file__)
    pytest.main([basename, "-s", "--tb=native"])

最后的结果是这样的:

do test with mock
config.make_db_table_name = this_is_mock_table_prod
without mock
config.make_db_table_name = my_table_prod