Python unittest.mock

By | 7月 29, 2023

本文翻译自:Understanding the Python Mock Object Library

测试用于检验代码是否正确、可靠。第三方的依赖导致测试不稳定,unittest.mock可以帮你模拟第三方的行为,保证测能以可预测的结果检验代码。

unittest.mock是Python 3.3引入标准库的,Python 3.3之前的版本需要单独安装:

$ pip install mock

Mock类

Mock类用来创建一个对象模拟真实对象。

创建一个mock对象:

>>> from unittest.mock import Mock
>>> mock = Mock()
>>> mock
<Mock id='4561344720'>

现在你可以用创建的mock对象来替换真实对象,你可以把mock对象当参数传递给函数,或者赋值给一个变量。

# Pass mock as an argument to do_something()
do_something(mock)

# Patch the json library
json = mock

当你用mock对象替换真实对象时,mock对象要具有真实对象一样的属性或方法,否则执行出错。

例如你要替换json库,并且你的代码用到json库中的dumps()方法,那么创建的mock对象也要有dumps()方法。

mock对象的属性和方法是lazy的

Mock对象能替换任何对象,那是因为当你访问Mock对象的属性或方法时,如果不存在则会自动创建它们

>>> mock.some_attribute
<Mock name='mock.some_attribute' id='4394778696'>
>>> mock.do_something()
<Mock name='mock.do_something()' id='4394778920'>

例如前面提到mock json库,当你调用dumps()方法时,mock对象会自动添加此方法。

>>> json = Mock()
>>> json.dumps()
<Mock name='mock.dumps()' id='4392249776'>

注意这里Mock对象的dumps()方法有两个特点:

  1. 它没有参数,实际上它可以接受任何参数。
  2. 它的返回值也是一个Mock对象,也就是说你可以继续对返回值做任何调用,返回值仍然是Mock对象,如此循环下去。
>>> json = Mock()
>>> json.loads('{"k": "v"}').get('k')
<Mock name='mock.loads().get()' id='4379599424'>

检查Mock对象方法的调用情况

检查方法是否被调用

  • .assert_called() 检查方法是否被调用过
  • .assert_called_once() 检查方法是否只被调用过一次
  • .assert_not_called() 检查方法是否没有被调用过

还可以添加参数,检查方法是否使用预期的参数被调用:

  • .assert_called_with(*args, **kwargs)
  • .assert_called_once_with(*args, **kwargs)
>>> from unittest.mock import Mock

>>> # Create a mock object
... json = Mock()

>>> json.loads('{"key": "value"}')
<Mock name='mock.loads()' id='4550144184'>

>>> # You know that you called loads() so you can
>>> # make assertions to test that expectation
... json.loads.assert_called()
>>> json.loads.assert_called_once()
>>> json.loads.assert_called_with('{"key": "value"}')
>>> json.loads.assert_called_once_with('{"key": "value"}')

>>> json.loads('{"key": "value"}')
<Mock name='mock.loads()' id='4550144184'>

>>> # If an assertion fails, the mock will raise an AssertionError
... json.loads.assert_called_once()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/Cellar/python/3.6.5/Frameworks/Python.framework/Versions/3.6/lib/python3.6/unittest/mock.py", line 795, in assert_called_once
    raise AssertionError(msg)
AssertionError: Expected 'loads' to have been called once. Called 2 times.

>>> json.loads.assert_called_once_with('{"key": "value"}')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/Cellar/python/3.6.5/Frameworks/Python.framework/Versions/3.6/lib/python3.6/unittest/mock.py", line 824, in assert_called_once_with
    raise AssertionError(msg)
AssertionError: Expected 'loads' to be called once. Called 2 times.

>>> json.loads.assert_not_called()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/Cellar/python/3.6.5/Frameworks/Python.framework/Versions/3.6/lib/python3.6/unittest/mock.py", line 777, in assert_not_called
    raise AssertionError(msg)
AssertionError: Expected 'loads' to not have been called. Called 2 times.

检查方法调用详情

json = Mock(),可访问json的如下属性,

  • json.loads.call_count — loads()调用次数
  • json.loads.call_args — loads()最后一次调用的参数
  • json.loads.call_args_list — loads()所有调用参数
  • json.method_calls — json所有调用过的方法及参数
>>> from unittest.mock import Mock

>>> # Create a mock object
... json = Mock()
>>> json.loads('{"key": "value"}')
<Mock name='mock.loads()' id='4391026640'>

>>> # Number of times you called loads():
... json.loads.call_count
1
>>> # The last loads() call:
... json.loads.call_args
call('{"key": "value"}')
>>> # List of loads() calls:
... json.loads.call_args_list
[call('{"key": "value"}')]
>>> # List of calls to json's methods (recursively):
... json.method_calls
[call.loads('{"key": "value"}')]

设置方法的return_value

创建一个my_calendar.py,添加is_weekday()方法用以判断今天是否是工作日,根据你执行日期不同,它返回结果可能是True,也可能是False。

from datetime import datetime

def is_weekday():
    today = datetime.today()
    # Python's datetime library treats Monday as 0 and Sunday as 6
    return (0 <= today.weekday() < 5)

# Test if today is a weekday
assert is_weekday()

对于上面的例子,你可以mock掉datetime,控制today()方法的返回值。

import datetime
from unittest.mock import Mock

# Save a couple of test days
tuesday = datetime.datetime(year=2019, month=1, day=1)
saturday = datetime.datetime(year=2019, month=1, day=5)

# Mock datetime to control today's date
datetime = Mock()

def is_weekday():
    today = datetime.datetime.today()
    # Python's datetime library treats Monday as 0 and Sunday as 6
    return (0 <= today.weekday() < 5)

# Mock .today() to return Tuesday
datetime.datetime.today.return_value = tuesday
# Test Tuesday is a weekday
assert is_weekday()
# Mock .today() to return Saturday
datetime.datetime.today.return_value = saturday
# Test Saturday is not a weekday
assert not is_weekday()

设置方法的side_effect

如果设置了side_effect,return_value将被忽略。可以对side_effect设置如下值:

  • Exception:调用时则抛出异常。
  • Function:调用mock方法会调用此方法,方法参数签名要一致。
  • Iterable:多次调用mock方法,会依次返回Iterable是的值,可以是Exception或正常Value。

下面例子,get_holidays()通过requests.get()访问API获得节假日列表。

设置side_effect = Timeout,模拟访问超时。

  • .assertRaises() — 检查是否抛出异常
import unittest
from requests.exceptions import Timeout
from unittest.mock import Mock

# Mock requests to control its behavior
requests = Mock()

def get_holidays():
    r = requests.get('http://localhost/api/holidays')
    if r.status_code == 200:
        return r.json()
    return None

class TestCalendar(unittest.TestCase):
    def test_get_holidays_timeout(self):
        # Test a connection timeout
        requests.get.side_effect = Timeout
        with self.assertRaises(Timeout):
            get_holidays()

if __name__ == '__main__':
    unittest.main()

给side_effect设置一个函数(方法签名和原函数要一致),调用mock方法时,会执行此函数。

import requests
import unittest
from unittest.mock import Mock

# Mock requests to control its behavior
requests = Mock()

def get_holidays():
    r = requests.get('http://localhost/api/holidays')
    if r.status_code == 200:
        return r.json()
    return None

class TestCalendar(unittest.TestCase):
    def log_request(self, url):
        # Log a fake request for test output purposes
        print(f'Making a request to {url}.')
        print('Request received!')

        # Create a new Mock to imitate a Response
        response_mock = Mock()
        response_mock.status_code = 200
        response_mock.json.return_value = {
            '12/25': 'Christmas',
            '7/4': 'Independence Day',
        }
        return response_mock

    def test_get_holidays_logging(self):
        # Test a successful, logged request
        requests.get.side_effect = self.log_request
        assert get_holidays()['12/25'] == 'Christmas'

if __name__ == '__main__':
    unittest.main()

给side_effect设置多个返回值,模拟多次调用返回不同的值。

import unittest
from requests.exceptions import Timeout
from unittest.mock import Mock

# Mock requests to control its behavior
requests = Mock()

def get_holidays():
    r = requests.get('http://localhost/api/holidays')
    if r.status_code == 200:
        return r.json()
    return None

class TestCalendar(unittest.TestCase):
    def test_get_holidays_retry(self):
        # Create a new Mock to imitate a Response
        response_mock = Mock()
        response_mock.status_code = 200
        response_mock.json.return_value = {
            '12/25': 'Christmas',
            '7/4': 'Independence Day',
        }
        # Set the side effect of .get()
        requests.get.side_effect = [Timeout, response_mock]
        # Test that the first request raises a Timeout
        with self.assertRaises(Timeout):
            get_holidays()
        # Now retry, expecting a successful response
        assert get_holidays()['12/25'] == 'Christmas'
        # Finally, assert .get() was called twice
        assert requests.get.call_count == 2

if __name__ == '__main__':
    unittest.main()

上面例子,第一次调用抛出了Timeout异常,第二次调用返回了response_mock。

patch()

patch通过字符串路径确定mock target。

my_calendar.py是被测试文件。

import requests
from datetime import datetime

def is_weekday():
    today = datetime.today()
    # Python's datetime library treats Monday as 0 and Sunday as 6
    return (0 <= today.weekday() < 5)

def get_holidays():
    r = requests.get('http://localhost/api/holidays')
    if r.status_code == 200:
        return r.json()
    return None

装饰器

patch()返回的是一个MagicMock对象,该对象默认实现了几乎所有的magic methods。

import unittest
from my_calendar import get_holidays
from requests.exceptions import Timeout
from unittest.mock import patch

class TestCalendar(unittest.TestCase):
    @patch('my_calendar.requests')
    def test_get_holidays_timeout(self, mock_requests):
            mock_requests.get.side_effect = Timeout
            with self.assertRaises(Timeout):
                get_holidays()
                mock_requests.get.assert_called_once()

with上下文

import unittest
from my_calendar import get_holidays
from requests.exceptions import Timeout
from unittest.mock import patch

class TestCalendar(unittest.TestCase):
    def test_get_holidays_timeout(self):
        with patch('my_calendar.requests') as mock_requests:
            mock_requests.get.side_effect = Timeout
            with self.assertRaises(Timeout):
                get_holidays()
                mock_requests.get.assert_called_once()

patch.object()

通过对象(可以是任何object:function, class, instance等)确定mock target。

import unittest
from my_calendar import requests, get_holidays
from unittest.mock import patch, Mock

response_mock = Mock()
response_mock.status_code = 200
response_mock.json.return_value = {
    '12/25': 'Christmas',
    '7/4': 'Independence Day',
}

class TestCalendar(unittest.TestCase):
    # @patch.object(requests, 'get', side_effect=requests.exceptions.Timeout)
    @patch.object(requests, 'get', return_value=response_mock)
    def test_get_holidays_timeout(self, mock_get):
        self.assertEqual('Christmas', get_holidays()['12/25'])
        mock_get.assert_called_with('http://localhost/api/holidays')

确定patch的目标

知道patch对象在哪里非常重要,如果patch错了,代码就不能正确执行。

基本原则是:在对象被查找的地方patch,这不一定就是它被定义的地方。

例如下面的代码,b.py使用了a.py里的greet方法。

a.py
def greet():
    return 'Hi'

b.py
from a import greet
def greet_the_word():
    return greet() + ' the world'

我们要patch b.py里的greet方法。

from unittest.mock import patch
from b import greet_the_word

def test_greet_the_world():
    with patch('a.greet') as mock_greet:
        mock_greet.return_value = 'Hello'
        assert 'Hi the world' == greet_the_word()

    with patch('b.greet') as mock_greet:
        mock_greet.return_value = 'Hello'
        assert 'Hello the world' == greet_the_word()

第五行的patch不成功,它patch的是a.py里的greet,但是b.py已经重新导入greet,应当patch b.py里的greet。

第九行的patch是正确的,结果变成了’Hello the world’。

spec – 限制Mock对象的属性

Mock对象会创建任意的属性和方法,这导致你访问一个错误的方法时它不会报错,这将带来危险。例如接口更改了,但是访问Mock对象的旧方法不会报错,因为它自动创建了。

可以通过spec参数控制Mock对象具有给定的属性,spec参数可以是:

  • str[]:只有列表里的属性。
  • function:没有任何属性,是个可以调用的function。
  • object (class or instance): 属性列表是dir(object)返回值。
from unittest.mock import Mock

def test_spec_list():
    m = Mock(spec=['sing', 'dance'])
    m.sing()
    m.dance()
    # m.other_attr
    # AttributeError: Mock object has no attribute 'other_attr'

def greet(name):
    return f'Hi {name}'

def test_spec_function():
    mock_greet = Mock(spec=greet)
    mock_greet.return_value = 'abc'
    assert 'abc' == mock_greet('Alex')
    mock_greet.assert_called_with('Alex')
    # mock_greet.other_attr
    # AttributeError: Mock object has no attribute 'other_attr'

class Person:
    def greet(self):
        print('Hello World')

def test_spec_class():
    mock_p = Mock(spec=Person)
    mock_p.greet.return_value = "Knock knock"
    assert "Knock knock" == mock_p.greet()
    mock_p.greet.assert_called_once()
    # mock_p.other_attr
    # AttributeError: Mock object has no attribute 'other_attr'

    p = Person()
    p2 = Mock(spec=p)
    p2.greet()

Mock(spec=)这种写法只限制mock对象有哪些属性、方法,但是调用方法时可以是任意参数

Autospeccing

自动spec可以mock对象的API跟原有的一样,调用mock对象的方法时如果参数不符合,会raise TypeError。使用下面方法自动spec,它们返回的都是MagicMock对象。

  • create_autospec()
  • patch(autospec=True)
  • patch.object(autospec=True)

下面代码b.py使用了a.py里的方法和类,我们要测试b.py是的方法。

a.py
----------------------------------------------------
def add(num1, num2):
    return num1 + num2

class Person:
    def __init__(self, name):
        self.name = name
    def greet(self, somebody):
        return f'Hi {somebody}, my name is {self.name}'


b.py
-----------------------------------------------------
import a

def do_add(num1, num2):
    return a.add(num1, num2)

def do_greet(somebody):
    p = a.Person('Alex')
    return p.greet(somebody)

测试类,test_do_greet()方法mock了a.py里的Person class,控制Person类生成指定的对象。这在Java是很难实现的,要用PowerMock以及特殊的Runner执行,但是Python却很轻松实现。

import a
import b
from unittest.mock import patch, create_autospec

def test_do_add():
    with patch('a.add', autospec=True) as mock_add:
        mock_add.return_value = 90
        assert b.do_add(1, 2) == 90

def test_do_greet():
    mock_person = create_autospec(a.Person)
    mock_person.greet.return_value = 'I am alien'

    # Mock a class to create certain instance
    with patch('a.Person') as mock_person_cls:
        mock_person_cls.return_value = mock_person
        assert b.do_greet('Tom') == 'I am alien'

总结及最佳实践

  • Mock():用于创建万能对象,访问什么创建什么。
  • MagicMock():是Mock的子类,对magic methods有默认实现。
  • patch():通过路径给target打补丁。
  • patch.object():通过对象来打补丁,对象必须在当前module里。要import module才能给module打补丁。

最佳实践

  • 创建mock对象:不要手动创建Mock或MagicMock对象,使用create_autospec来创建具有spec的MagicMock对象。
  • 局部mock:使用patch.object(target, attribute, autospec=True),不会出现拼写错误。

mock.patch和mock.patch.object功能一样,只是后者更容易写。