2024-02-11

Helpful pytest plugins

These are some helpful pytest plugins that I've collected over the years. They're here as code rather than as a library because it's useful to be able to modify them for a particular domain/codebase. Some of the hacks are pretty ugly as they utilise fairly deep implementation details.

Setup

Imports for the code that follows. External dependencies (besides pytest itself) are icdiff, sqlparse, rich and sqlalchemy.
import contextlib
import datetime
import enum
import fcntl
import importlib
import inspect
import io
import os
import pdb
import re
import socket
import struct
import termios
from collections.abc import Iterator
from dataclasses import dataclass, is_dataclass
from decimal import Decimal
from socket import socket as original_socket
from typing import Any, Generic, TypeVar, cast
from unittest import mock
from uuid import UUID

import icdiff
import pytest
import rich.console
import sqlparse
from _pytest.monkeypatch import MonkeyPatch
from sqlalchemy import create_engine, event, sql, text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Query, Session

We're going to make a file, some/dir/plugin.py. You can make all of its fixtures and hooks available by adding the following to a conftest.py:

# In conftest.py: load some/dir/plugin.py as a pytest plugin.
pytest_plugins = ["some.dir.plugin"]

From now on, all the code is assumed to live in plugin.py.

Nicer comparisons

Pytest's comparison output often leaves a lot to be desired. Defining a pytest_assertrepr_compare hook overrides pytest's explanation of a failed comparison. If the hook returns None, pytest falls back to its default output. This is basically a minimal version of pytest-icdiff. It's often handy to add domain-specific pretty printers to aid debugging, so it's useful to have it in "raw" form. The version below only kicks in for == comparisons when pytest is run with -vv.

Example output:

Code:

def get_terminal_width() -> int:  # from: https://gist.github.com/jtriley/1108174
    """Return the number of columns of the controlling terminal.

    Opens ``os.ctermid()`` and queries it with the TIOCGWINSZ ioctl.
    Raises OSError if the terminal cannot be opened or queried; the module-level
    ``WIDTH`` computation catches that and falls back to a default.
    """
    fd = os.open(os.ctermid(), os.O_RDONLY)
    try:
        # struct winsize is four unsigned shorts: rows, cols, xpixel, ypixel.
        winsize = fcntl.ioctl(fd, termios.TIOCGWINSZ, bytes(struct.calcsize("HHHH")))
        _rows, cols, _xpixel, _ypixel = struct.unpack("HHHH", winsize)
    finally:
        # Close even when the ioctl fails, so the descriptor never leaks.
        os.close(fd)
    return int(cols)

# Total width of the side-by-side diff. Starts at a conventional 80 columns and
# is replaced by (terminal columns - 10) only if the terminal query succeeds.
WIDTH = 80
with contextlib.suppress(Exception):
    WIDTH = get_terminal_width() - 10

def pytest_assertrepr_compare(config: Any, op: str, left: Any, right: Any) -> list[str] | None:
    very_verbose = config.option.verbose >= 2
    if not very_verbose:
        return None
    if op != "==":
        return None
    try:
        if abs(left + right) < 100:
            return None
    except TypeError:
        pass

    replace_mocked_fields(left, right)
    try:
        if isinstance(left, str) and isinstance(right, str):
            pretty_left = left.splitlines()
            pretty_right = right.splitlines()
        else:
            pretty_left = rich_repr(left).splitlines()
            pretty_right = rich_repr(right).splitlines()
        differ = icdiff.ConsoleDiff(cols=WIDTH, tabsize=4)
        icdiff_lines = differ.make_table(pretty_left, pretty_right, context=True, numlines=10)
        return (
            ["equals failed"]
            + ["ACTUAL".center(WIDTH // 2 - 1) + "|" + "EXPECTED".center(WIDTH // 2)]
            + ["-" * WIDTH]
            + [icdiff.color_codes["none"] + line for line in icdiff_lines]
        )
    except Exception:  # if it breaks at all, just do a normal diff
        return None

# Helpers

def rich_repr(o: Any) -> str:
    """Pretty-print *o* with rich and return the result as plain text.

    No colour or highlighting; lines are wrapped at ``WIDTH // 2`` columns, half
    the width of the side-by-side diff table.
    """
    sink = io.StringIO()
    console = rich.console.Console(
        file=sink,
        width=WIDTH // 2,
        tab_size=4,
        no_color=True,
        highlight=False,
        log_time=False,
        log_path=False,
    )
    console.print(o)
    return sink.getvalue()

def replace_mocked_fields(left: Any, right: Any) -> None:
    keys: set[str] | range
    if is_dataclass(left) and is_dataclass(right):
        left = left.__dict__
        right = right.__dict__
        keys = left.keys() & right.keys()
    elif isinstance(left, list) and isinstance(right, list):
        keys = range(min(len(left), len(right)))
    else:
        return

    for key in keys:
        if right[key] is mock.ANY:
            left[key] = mock.ANY
        else:
            replace_mocked_fields(left[key], right[key])

Prevent network calls

All your tests should run on the train, right? Adding the following prevents network calls by default. They are allowed only in tests that request the allow_network_calls fixture.

@pytest.fixture(scope="class", autouse=True)
def stop_network_calls(_get_event_loop):
    """Autouse, class-scoped fixture that blocks socket creation.

    Replaces ``socket.socket`` with a function that raises, so code creating a
    socket through the ``socket`` module fails loudly. Tests that need the
    network can request ``allow_network_calls``. Yields the monkeypatch object
    so further patches can be added and undone together with this one.

    NOTE(review): ``_get_event_loop`` is a fixture defined elsewhere; presumably
    requested so an event loop exists before sockets are blocked — confirm.
    """

    def _socket(*_, **__):
        raise Exception("stop making network calls!")

    # Use pytest.MonkeyPatch explicitly: this file later defines its own
    # ``MonkeyPatch`` class, which shadows the ``_pytest.monkeypatch`` import
    # by the time this fixture runs. The context manager guarantees the undo.
    with pytest.MonkeyPatch.context() as mpatch:
        mpatch.setattr(socket, "socket", _socket)
        yield mpatch


@pytest.fixture
def allow_network_calls(monkeypatch):
    """Opt a test back in to real sockets.

    Puts the real socket class (``original_socket``, imported at module load)
    back on the ``socket`` module via pytest's ``monkeypatch`` fixture.
    """
    real_socket_class = original_socket
    monkeypatch.setattr(socket, "socket", real_socket_class)

Prettier pretty print with Postgres SQL support

This adds a new command to the debugger, ppp, that can pretty-print Postgres SQLAlchemy queries. The use case is:
1. Slap an assert False after constructing a SQLAlchemy query in the application code.
2. Run the tests with --pdb.
3. At the debugger prompt, run ppp my_query and copy-paste the output into psql.

class LiteralCompiler(postgresql.psycopg2.PGCompiler_psycopg2):
    """Postgres (psycopg2 dialect) compiler that renders bound parameters inline.

    Produces a self-contained SQL string for debugging output (see ``pp_sql``).
    Only the Python types handled in :meth:`render_literal_value` can be
    rendered.
    """

    # see also https://stackoverflow.com/a/9898141/4865874
    def visit_bindparam(
        self,
        bindparam,
        within_columns_clause=False,
        literal_binds=False,
        **kwargs,
    ):
        # Render every bind parameter as a literal instead of a placeholder.
        return super().render_literal_bindparam(
            bindparam,
            within_columns_clause=within_columns_clause,
            literal_binds=literal_binds,
            **kwargs,
        )

    def render_literal_value(self, value, type_):
        """Return *value* as a Postgres SQL literal.

        ``type_`` is ignored; the rendering is chosen from the Python type of
        *value* alone.

        Raises:
            NotImplementedError: for Python types not handled below.
        """

        def quoted(text: str) -> str:
            # Standard SQL string literal: double any embedded single quotes.
            return "'" + text.replace("'", "''") + "'"

        if isinstance(value, str):
            return quoted(value)
        elif isinstance(value, UUID):
            return quoted(str(value))
        elif value is None:
            return "NULL"
        elif isinstance(value, Decimal):
            # str(), not repr(): repr(Decimal("1.5")) is "Decimal('1.5')".
            return str(value)
        elif isinstance(value, (float, int)):
            return repr(value)
        elif isinstance(value, (datetime.date, datetime.time)):
            # datetime.datetime subclasses datetime.date, so it lands here too.
            return quoted(value.isoformat())
        elif isinstance(value, enum.Enum):
            return quoted(str(value.value))
        else:
            raise NotImplementedError(
                f"Don't know how to literal-quote value {value}"
            )


def pp_sql(qry: sql.expression.ClauseElement | Query) -> str:
    """Render a SQLAlchemy clause or ORM ``Query`` as formatted Postgres SQL.

    Bound parameters are inlined by ``LiteralCompiler``. The SQL is then
    re-indented and keyword-upper-cased with sqlparse, and a header line is
    prepended.
    """
    statement = qry.statement if isinstance(qry, Query) else qry
    statement = cast(sql.expression.ClauseElement, statement)

    literal_compiler = LiteralCompiler(postgresql.psycopg2.dialect(), statement)
    formatted = sqlparse.format(
        literal_compiler.process(statement),
        reindent=True,
        keyword_case="upper",
        indent_width=4,
        indent_tabs=False,
        wrap_after=20,
    )
    return f"<Raw SQL query:>\n{formatted}"


def ppp(self: Any, arg: Any) -> str | None:
    # try get the value from the current scope
    try:
        obj = self._getval(arg)
    except Exception:
        return None

    # if it looks like a SQL query, try format it
    if isinstance(obj, (sql.expression.ClauseElement, Query)):
        try:
            return pp_sql(obj)
        except Exception:
            pass

    # else try nicely pprint it
    try:
        return rich_repr(obj)
    except Exception as e:
        return f">>> Failed to pretty print {obj} with exception {e}"

# call this at the end of plugin.py
def _set_up_prettyprinter() -> None:
    """Add a ``ppp`` command to :class:`pdb.Pdb` that prints ``ppp(self, arg)``."""

    def do_ppp(self: pdb.Pdb, arg: str) -> None:
        """ppp expression
        Pretty-print the value of the expression; SQLAlchemy queries are
        shown as formatted Postgres SQL.
        """
        output = ppp(self, arg)
        # ppp returns None when the expression cannot be evaluated; print
        # nothing rather than a misleading "None".
        if output is not None:
            # Honour the debugger's configured output stream.
            print(output, file=self.stdout)

    pdb.Pdb.do_ppp = do_ppp  # type: ignore

Quickly clean Postgres tables

By convention, tests will often DROP and CREATE tables between tests. Normally it is sufficient, and faster, to just delete the data (in the correct table order) and reset all the sequences. See also.

def clean_tables(session, sqlalchemy_base):
    """Delete all rows from *sqlalchemy_base*'s tables and restart sequences.

    A faster alternative to dropping and recreating tables between tests.
    Tables are emptied in the reverse of ``metadata.sorted_tables`` order.
    Every sequence in the current schema is then restarted at its own START
    value. The session is committed at the end.
    """

    def quote_ident(name: str) -> str:
        # Always-quoted Postgres identifier: preserves mixed case and special
        # characters that an unquoted name would lower-case or reject.
        return '"' + name.replace('"', '""') + '"'

    for table in reversed(sqlalchemy_base.metadata.sorted_tables):
        session.execute(table.delete())

    # Named sequences_query (not ``sql``) so the imported sqlalchemy ``sql``
    # module is not shadowed.
    sequences_query = text(
        "SELECT schemaname, sequencename FROM pg_sequences "
        "WHERE schemaname = current_schema()"
    )
    # Materialise the rows before issuing further statements on the session.
    sequences = list(session.execute(sequences_query))
    for schema_name, sequence_name in sequences:
        qualified = f"{quote_ident(schema_name)}.{quote_ident(sequence_name)}"
        # Bare RESTART goes back to the sequence's START value (1 by default).
        session.execute(text(f"ALTER SEQUENCE {qualified} RESTART"))

    session.commit()

Typed monkeypatch

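This is a bit of a bodge under the hood; it should really use ast rather than regex... Usage is patch(some_module.some_attr).to(replacement), which lets a type checker verify that replacement is compatible with the type of some_module.some_attr.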

T = TypeVar("T")


@dataclass
class _MonkeyPatchSetAttr(Generic[T]):
    """A pending replacement of ``module.attr``; call :meth:`to` to apply it."""

    monkeypatch: Any  # the wrapped pytest ``monkeypatch`` fixture
    module: Any  # object that owns the attribute being replaced
    attr: str  # attribute name on ``module``

    def to(self, to: T) -> None:
        """Set ``module.attr`` to *to* through the wrapped monkeypatch's ``setattr``."""
        self.monkeypatch.setattr(self.module, self.attr, to)


@dataclass
class MonkeyPatch:
    """Type-checked front end for pytest's ``monkeypatch``.

    Usage::

        patch(some_module.some_attr).to(replacement)

    ``__call__`` never reads the *value* it is given; the value only lets a
    type checker infer ``T``. Instead it re-reads the source line of the call,
    extracts the dotted name inside ``patch(...)`` with a regex, and evaluates
    the owner (``some_module``) in the caller's scope. Consequently the call
    must fit on one line and its argument must be a plain dotted name.

    NOTE: this class shadows the ``_pytest.monkeypatch.MonkeyPatch`` import at
    the top of this file; use ``pytest.MonkeyPatch`` for pytest's own class.
    """

    monkeypatch: Any  # the wrapped pytest ``monkeypatch`` fixture

    def __call__(self, from_: T) -> _MonkeyPatchSetAttr[T]:
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        try:
            if caller is None:
                raise RuntimeError("patch() requires frame introspection to find its caller")
            # Inspect only the calling frame; inspect.stack() would read the
            # source of every frame on the stack.
            code_context = inspect.getframeinfo(caller, context=1).code_context
            if not code_context:
                raise RuntimeError("patch() could not read the source of its call site")
            line = code_context[0]
            match = re.match(r".*patch\(\s*([\w.]+)\s*\)", line)
            if match is None:
                raise ValueError(f"could not find `patch(<owner>.<attr>)` in {line.strip()!r}")
            module_name, _, attr = match.group(1).rpartition(".")
            if not module_name:
                raise ValueError(
                    f"patch() needs a dotted target such as module.attr, got {match.group(1)!r}"
                )
            # eval of a regex-restricted dotted name taken from the test's own
            # source, in the test's own scope: not external input.
            module = eval(module_name, caller.f_globals, caller.f_locals)
        finally:
            # Drop frame references to avoid reference cycles.
            del frame, caller
        return _MonkeyPatchSetAttr(self.monkeypatch, module, attr)

@pytest.fixture
def patch(monkeypatch: Any) -> Iterator[MonkeyPatch]:
    """Provide the type-checked ``patch(module.attr).to(value)`` helper.

    Every patch is applied through the given pytest ``monkeypatch`` fixture's
    ``setattr``.
    """
    typed_patcher = MonkeyPatch(monkeypatch)
    yield typed_patcher