2024-02-11
These are some helpful pytest plugins that I've collected over the years. They're here as code rather than as a library because it's useful to be able to modify them for a particular domain/codebase. Some of the hacks are pretty ugly as they utilise fairly deep implementation details.
Dependencies: `icdiff`, `sqlparse`, `rich` and `sqlalchemy`.
import contextlib
import datetime
import enum
import fcntl
import importlib
import inspect
import io
import os
import pdb
import re
import socket
import struct
import termios
from collections.abc import Iterator
from dataclasses import dataclass, is_dataclass
from decimal import Decimal
from socket import socket as original_socket
from typing import Any, Generic, TypeVar, cast
from unittest import mock
from uuid import UUID

import icdiff
import pytest
import rich.console
import sqlparse
from _pytest.monkeypatch import MonkeyPatch
from sqlalchemy import create_engine, event, sql, text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Query, Session
We're going to make a file, `some/dir/plugin.py`; you can make all the fixtures/plugins available by adding the following to a `conftest.py`:
pytest_plugins = ["some.dir.plugin"]
From now on, all the code is assumed to be in `plugin.py`.
Pytest's comparisons often leave a lot to be desired. Defining a `pytest_assertrepr_compare` hook overrides pytest's comparison output; if the function returns `None`, pytest reverts to the default comparison. This is basically just a minimal version of pytest-icdiff. It's often handy to add domain-specific pretty printers to aid debugging, so it's useful to have it in "raw" form. (The hook below only kicks in when pytest is run with `-vv`.)
Example output:
Code:
def get_terminal_width() -> int:  # from: https://gist.github.com/jtriley/1108174
    """Return the number of columns of the controlling terminal.

    Opens the terminal named by ``os.ctermid()`` and asks for its size with
    the ``TIOCGWINSZ`` ioctl.

    Raises:
        OSError: if the terminal cannot be opened (e.g. no controlling
            terminal, as under CI) or the ioctl fails.
    """
    fd = os.open(os.ctermid(), os.O_RDONLY)
    try:
        # struct winsize is four unsigned shorts: rows, cols, xpixel, ypixel.
        winsize = fcntl.ioctl(fd, termios.TIOCGWINSZ, struct.pack("HHHH", 0, 0, 0, 0))  # type: ignore
    finally:
        # Close the descriptor even if the ioctl fails, so it never leaks.
        os.close(fd)
    _rows, cols, _xpixel, _ypixel = struct.unpack("HHHH", winsize)
    return int(cols)
# Total width of the diff table: the terminal width minus a 10-column margin,
# or 80 columns when the terminal size cannot be read (e.g. no tty).
WIDTH = 80
with contextlib.suppress(Exception):
    WIDTH = get_terminal_width() - 10
def pytest_assertrepr_compare(config: Any, op: str, left: Any, right: Any) -> list[str] | None:
very_verbose = config.option.verbose >= 2
if not very_verbose:
return None
if op != "==":
return None
try:
if abs(left + right) < 100:
return None
except TypeError:
pass
replace_mocked_fields(left, right)
try:
if isinstance(left, str) and isinstance(right, str):
pretty_left = left.splitlines()
pretty_right = right.splitlines()
else:
pretty_left = rich_repr(left).splitlines()
pretty_right = rich_repr(right).splitlines()
differ = icdiff.ConsoleDiff(cols=WIDTH, tabsize=4)
icdiff_lines = differ.make_table(pretty_left, pretty_right, context=True, numlines=10)
return (
["equals failed"]
+ ["ACTUAL".center(WIDTH // 2 - 1) + "|" + "EXPECTED".center(WIDTH // 2)]
+ ["-" * WIDTH]
+ [icdiff.color_codes["none"] + line for line in icdiff_lines]
)
except Exception: # if it breaks at all, just do a normal diff
return None
# Helpers
def rich_repr(o: Any) -> str:
    """Return *o* pretty-printed by rich as plain text.

    Colour, highlighting and log decorations are disabled, and output wraps
    at half of ``WIDTH`` so two renderings fit side by side in the diff.
    """
    buffer = io.StringIO()
    console = rich.console.Console(
        file=buffer,
        width=WIDTH // 2,
        tab_size=4,
        no_color=True,
        highlight=False,
        log_time=False,
        log_path=False,
    )
    console.print(o)
    # getvalue() returns the whole buffer regardless of the stream position.
    return buffer.getvalue()
def replace_mocked_fields(left: Any, right: Any) -> None:
keys: set[str] | range
if is_dataclass(left) and is_dataclass(right):
left = left.__dict__
right = right.__dict__
keys = left.keys() & right.keys()
elif isinstance(left, list) and isinstance(right, list):
keys = range(min(len(left), len(right)))
else:
return
for key in keys:
if right[key] is mock.ANY:
left[key] = mock.ANY
else:
replace_mocked_fields(left[key], right[key])
All your tests should run on the train, right? Adding the following will block network calls by default, and allow them only in tests that request the `allow_network_calls` fixture.
@pytest.fixture(scope="class", autouse=True)
def stop_network_calls(_get_event_loop):
    """Replace ``socket.socket`` with a stub that raises, for every test.

    Tests that really need the network request ``allow_network_calls``.
    Only lookups of ``socket.socket`` through the ``socket`` module are
    affected; code holding its own reference to the class (like
    ``original_socket`` here) is not.

    Depends on the ``_get_event_loop`` fixture, which is defined elsewhere.
    Presumably this makes sure the event loop (which may create sockets
    internally) exists before sockets are blocked. TODO confirm.

    Yields the ``pytest.MonkeyPatch`` holding the patch. It is undone when
    the fixture scope ends.
    """

    def _socket(*_, **__):
        raise Exception("stop making network calls!")

    # Use pytest.MonkeyPatch explicitly: the bare name `MonkeyPatch` in this
    # module is rebound to the `patch` helper class defined further down.
    with pytest.MonkeyPatch.context() as mpatch:
        mpatch.setattr(socket, "socket", _socket)
        yield mpatch
@pytest.fixture
def allow_network_calls(monkeypatch):
    """Opt a single test back into real network access.

    Points ``socket.socket`` back at the original socket class for the
    duration of the test, using the function-scoped ``monkeypatch``.
    """
    real_socket_class = original_socket
    monkeypatch.setattr(socket, "socket", real_socket_class)
This adds a new command to the debugger, `ppp`, that can pretty print Postgres SQLAlchemy queries. The use case is: slap an `assert False` after constructing a SQLAlchemy query in the application code, run the tests with `--pdb`, then run `ppp my_query` and copy-paste the output into `psql`.
class LiteralCompiler(postgresql.psycopg2.PGCompiler_psycopg2):
    """Postgres compiler that inlines bound parameters as SQL literals.

    The output is meant for reading or pasting into psql while debugging.
    Literal quoting here is minimal, so do not use it to build queries from
    untrusted input.

    See also https://stackoverflow.com/a/9898141/4865874
    """

    def visit_bindparam(
        self,
        bindparam,
        within_columns_clause=False,
        literal_binds=False,
        **kwargs,
    ):
        # Render every bind parameter as a literal instead of a placeholder.
        return super().render_literal_bindparam(
            bindparam,
            within_columns_clause=within_columns_clause,
            literal_binds=literal_binds,
            **kwargs,
        )

    def render_literal_value(self, value, type_):
        """Return *value* formatted as a Postgres literal.

        Raises:
            NotImplementedError: for value types not handled below.
        """
        # Enums first: IntEnum / str-mixin members are also int / str
        # instances, but their repr()/str() is not their value.
        if isinstance(value, enum.Enum):
            return "'" + str(value.value).replace("'", "''") + "'"
        if isinstance(value, str):
            # Standard SQL escaping of embedded single quotes.
            return "'" + value.replace("'", "''") + "'"
        if isinstance(value, UUID):
            return f"'{value}'"
        if value is None:
            return "NULL"
        if isinstance(value, (float, int, Decimal)):
            # str(), not repr(): repr(Decimal("1.5")) is "Decimal('1.5')".
            return str(value)
        if isinstance(value, (datetime.date, datetime.time)):
            # Covers datetime.datetime too (a subclass of date).
            return f"'{value.isoformat()}'"
        raise NotImplementedError(
            f"Don't know how to literal-quote value {value}"
        )
def pp_sql(qry: sql.expression.ClauseElement | Query) -> str:
    """Compile *qry* to Postgres SQL with values inlined, then pretty-format it.

    ORM ``Query`` objects are unwrapped to their underlying statement first.
    Keywords are upper-cased and clauses re-indented with sqlparse.
    """
    statement = qry.statement if isinstance(qry, Query) else qry
    statement = cast(sql.expression.ClauseElement, statement)
    rendered = LiteralCompiler(postgresql.psycopg2.dialect(), statement).process(statement)
    formatted = sqlparse.format(
        rendered,
        reindent=True,
        keyword_case="upper",
        indent_width=4,
        indent_tabs=False,
        wrap_after=20,
    )
    return "<Raw SQL query:>\n" + formatted
def ppp(self: Any, arg: Any) -> str | None:
# try get the value from the current scope
try:
obj = self._getval(arg)
except Exception:
return None
# if it looks like a SQL query, try format it
if isinstance(obj, (sql.expression.ClauseElement, Query)):
try:
return pp_sql(obj)
except Exception:
pass
# else try nicely pprint it
try:
return rich_repr(obj)
except Exception as e:
return f">>> Failed to pretty print {obj} with exception {e}"
# call this at the end of plugin.py
def _set_up_prettyprinter() -> None:
# adds `ppp` to the debugger
pdb.Pdb.do_ppp = lambda self, arg: print(ppp(self, arg)) # type: ignore
By convention, tests will often `DROP` and `CREATE` tables between tests. Normally, it should be sufficient, and faster, to just delete the data (in the correct table order) and reset all the sequences. See also.
def clean_tables(session, sqlalchemy_base):
    """Delete every row from the mapped tables, reset sequences, and commit.

    A cheaper alternative to dropping and re-creating tables between tests.
    The sequence reset queries ``pg_sequences``, so this is Postgres-only.

    Args:
        session: SQLAlchemy session connected to the test database.
        sqlalchemy_base: declarative base whose ``metadata.sorted_tables``
            lists the tables to empty.
    """
    # sorted_tables puts parents before children; delete in reverse so
    # foreign-key references are removed before the rows they point to.
    for table in reversed(sqlalchemy_base.metadata.sorted_tables):
        session.execute(table.delete())

    # Reset every sequence in the current schema. quote_ident keeps
    # mixed-case or reserved-word names valid when interpolated below.
    # (Named `sequences_query` so it doesn't shadow the `sql` module.)
    sequences_query = text(
        "SELECT quote_ident(schemaname) || '.' || quote_ident(sequencename) "
        "FROM pg_sequences "
        "WHERE schemaname = current_schema()"
    )
    for (sequence,) in session.execute(sequences_query):
        session.execute(text(f"ALTER SEQUENCE {sequence} RESTART WITH 1"))
    session.commit()
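The following `patch` fixture gives a terser monkeypatching syntax: `patch(some_module.attr).to(new_value)`. It's a bit of a bodge under the hood; it should really use `ast` rather than a regex...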
T = TypeVar("T")


@dataclass
class _MonkeyPatchSetAttr(Generic[T]):
    """A pending patch of ``module.attr``, completed by calling :meth:`to`."""

    monkeypatch: Any  # pytest's monkeypatch fixture (anything with .setattr)
    module: Any  # object owning the attribute: module, class, instance, ...
    attr: str  # name of the attribute to replace

    def to(self, to: T) -> None:
        """Replace ``module.attr`` with *to* via ``monkeypatch.setattr``."""
        self.monkeypatch.setattr(self.module, self.attr, to)


# Dotted expression inside a `patch(...)` call, e.g. `patch(mod.attr)`.
_PATCH_CALL_RE = re.compile(r"patch\(\s*([\w.]+)\s*\)")


@dataclass
class MonkeyPatch:
    """Terse monkeypatching: ``patch(some_module.attr).to(new_value)``.

    Works by reading the caller's source line and extracting the dotted
    expression inside ``patch(...)`` with a regex. Everything before the
    last dot is evaluated in the caller's scope to find the owning object.

    Limitations:

    * the call must fit on one source line,
    * the argument must be a dotted name such as ``patch(a.b.c)``, and
    * only the first ``patch(...)`` on a line is recognised.
    """

    monkeypatch: Any  # pytest's monkeypatch fixture

    def __call__(self, from_: T) -> _MonkeyPatchSetAttr[T]:
        """Return a pending patch for the attribute named in the call.

        *from_* only drives type inference for ``.to()``. Its runtime value
        is ignored; the target is recovered from the caller's source text.

        Raises:
            RuntimeError: if frame introspection is unavailable.
            ValueError: if the call site is not ``patch(<owner>.<attr>)``.
        """
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        del frame  # don't keep a reference cycle through our own frame
        if caller is None:
            raise RuntimeError("patch(...) needs frame introspection to find its target")

        # Cheaper than inspect.stack(), which reads source for every frame.
        code_context = inspect.getframeinfo(caller, context=1).code_context
        code = code_context[0] if code_context else ""
        match = _PATCH_CALL_RE.search(code)
        if match is None:
            raise ValueError(f"could not find `patch(<owner>.<attr>)` in source line: {code!r}")

        module_name, _, attr = match.group(1).rpartition(".")
        if not module_name:
            raise ValueError(
                f"patch target {match.group(1)!r} must be a dotted name like `module.attr`"
            )
        # eval runs on the test's own source text, not on external input.
        module = eval(module_name, caller.f_globals, caller.f_locals)
        return _MonkeyPatchSetAttr(self.monkeypatch, module, attr)
@pytest.fixture
def patch(monkeypatch: Any) -> Iterator[MonkeyPatch]:
    """Provide the terse ``patch(module.attr).to(value)`` helper.

    Wraps pytest's ``monkeypatch`` fixture, which performs the actual patching.
    """
    helper = MonkeyPatch(monkeypatch)
    yield helper