Some time ago, I watched a video named “The Absolute Best Intro to Monads For Software Engineers” on YouTube [link] that I instantly shared on the CQ4DS discord [link].
It is indeed a great introduction to monads, an otherwise notoriously inaccessible concept. Unfortunately, the code itself is in TypeScript. Never mind, it’s an excellent opportunity for a blog post. I can also practice some niche typehints that we recently added to our Code Quality for Data Scientists training course. Get in touch for more information on this at: https://hypergolic.co.uk/contact/
There have been other attempts to explain monads, like “Railway Oriented Programming” by Scott Wlaschin [link] (in F#), but I haven’t found anything like this in Python.
Setup
The first step is to create a decent working environment:
poetry init -n
poetry env use 3.11
poetry add ruff black pydantic pytest mypy
In pyproject.toml, change the Python version constraint to the following, which allows any 3.11.x release but not 3.12 (the default poetry generates is a bad one):
python = "~3.11"
You can run the type checking and the tests with:
poetry run python -m mypy . && poetry run python -m pytest
I also just copied a bunch of settings for black, ruff and mypy from FastAPI and Pydantic. No point reinventing the wheel. They must have thought through why they use these settings and not others.
[tool.black]
skip-string-normalization = true
line-length = 120
[tool.ruff]
# Same as Black.
line-length = 120
exclude = ["jupyter_notebook_config.py"]
select = [
"E", # pycodestyle errors (settings from FastAPI, thanks, @tiangolo!)
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"C", # flake8-comprehensions
"B", # flake8-bugbear
]
ignore = [
"E501", # line too long, handled by black
"C901", # too complex
]
[tool.ruff.isort]
order-by-type = true
relative-imports-order = "closest-to-furthest"
extra-standard-library = ["typing"]
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
known-first-party = []
[tool.pytest.ini_options]
pythonpath = [
"."
]
addopts = "-ra -q"
testpaths = [
"tests",
]
[tool.mypy]
plugins = [
"pydantic.mypy"
]
ignore_missing_imports = true # (settings from Pydantic, thanks, @samuel_colvin!)
follow_imports = "skip"
warn_redundant_casts = true
warn_unused_ignores = true
disallow_any_generics = true
check_untyped_defs = true
no_implicit_reexport = true
# for strict mypy: (this is the tricky one :-))
disallow_untyped_defs = true
[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
After this much preparation, let’s start with the actual exercise.
Version 1: The basics
Instead of defining what monads are and how they are implemented, I’ll just follow the flow of the video.
Let’s take two small functions (note the typehints):
def square_v1(x: int) -> int:
    return x * x


def add_one_v1(x: int) -> int:
    return x + 1
And test them with a small test:
from unittest import TestCase


class TestFirst(TestCase):
    def test_first(self) -> None:
        self.assertEqual(add_one_v1(square_v1(2)), 5)
Nothing complicated so far.
But let’s say we also want to log the order of operations for record-keeping purposes. This is a typical requirement if you have to audit your operations.
Version 2: Logs
This is not that difficult; we just create a pydantic type with the right fields:
from pydantic import BaseModel


class IntWithLogs(BaseModel):
    result: int
    logs: list[str]
And modify the original version of the functions to use this type:
def square_v2(x: IntWithLogs) -> IntWithLogs:
    return IntWithLogs(
        result=x.result * x.result,
        logs=x.logs + [f'Squared {x.result} to get {x.result * x.result}'],
    )


def add_one_v2(x: IntWithLogs) -> IntWithLogs:
    return IntWithLogs(
        result=x.result + 1,
        logs=x.logs + [f'Added 1 to {x.result} to get {x.result + 1}'],
    )
But to use these with the values we had before (ints), we need a new function that creates IntWithLogs instances from them. Something like:
def wrap_with_logs(x: int) -> IntWithLogs:
    return IntWithLogs(
        result=x,
        logs=[],
    )
Whenever we want to use the v2 version of our functions on an integer, we use this helper function to make it compatible.
We can also test these with:
class TestSecond(TestCase):
    def test_square(self) -> None:
        expected = IntWithLogs(
            result=16,
            logs=[
                'Squared 2 to get 4',
                'Squared 4 to get 16',
            ],
        )
        self.assertEqual(square_v2(square_v2(wrap_with_logs(2))), expected)
Enter the monad
These look fine, but we can refactor this into a better version. Currently, each function concatenates its own log entry onto the incoming logs. It would be better if each function only had to produce its own log entry and something else did the concatenation. That way, the functions remain atomic: only the relevant information is accessible inside the function body.
Let’s write the following “run_with_logs” function:
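from typing import Callable


def run_with_logs(input: IntWithLogs, transform: Callable[[int], IntWithLogs]) -> IntWithLogs:
    # Apply the transform to the unwrapped int, then keep the existing logs first, followed by the new ones
    new_int_with_logs = transform(input.result)
    return IntWithLogs(
        result=new_int_with_logs.result,
        logs=input.logs + new_int_with_logs.logs,
    )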
Note the “Callable[[int], IntWithLogs]” type, an indication that we expect a function that receives an integer and returns a result with a log.
The function takes the value out of the received IntWithLogs and passes it to the received function (which takes an integer as an argument). The logs of the resulting new IntWithLogs are then appended to the original logs.
This requires a slight rewrite of the original “square” and “add_one” functions:
def square_v3(x: int) -> IntWithLogs:
    return IntWithLogs(
        result=x * x,
        logs=[f'Squared {x} to get {x * x}'],
    )


def add_one_v3(x: int) -> IntWithLogs:
    return IntWithLogs(
        result=x + 1,
        logs=[f'Added 1 to {x} to get {x + 1}'],
    )
As you can see, they now take a primitive type and return a pydantic one. Internally, they only return the log of their own activity and are not concerned with what happens to the logs later (i.e. that they get concatenated).
class TestThird(TestCase):
    def test_square(self) -> None:
        actual = run_with_logs(run_with_logs(wrap_with_logs(5), add_one_v3), square_v3)
        expected = IntWithLogs(
            result=36,
            logs=[
                'Added 1 to 5 to get 6',
                'Squared 6 to get 36',
            ],
        )
        self.assertEqual(actual, expected)
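It helps to see why run_with_logs is needed at all: the v3 functions no longer compose by plain nesting, because each one takes an int but returns an IntWithLogs. The sketch below shows both the failing nested call and the chained version. To keep it dependency-free, a plain dataclass stands in for the pydantic model (an assumption; the behaviour shown does not depend on pydantic).

```python
from dataclasses import dataclass, field
from typing import Callable


# Dependency-free stand-in for the pydantic IntWithLogs model (assumption for this sketch).
@dataclass
class IntWithLogs:
    result: int
    logs: list[str] = field(default_factory=list)


def square_v3(x: int) -> IntWithLogs:
    return IntWithLogs(result=x * x, logs=[f'Squared {x} to get {x * x}'])


def add_one_v3(x: int) -> IntWithLogs:
    return IntWithLogs(result=x + 1, logs=[f'Added 1 to {x} to get {x + 1}'])


def run_with_logs(input: IntWithLogs, transform: Callable[[int], IntWithLogs]) -> IntWithLogs:
    new_int_with_logs = transform(input.result)
    return IntWithLogs(result=new_int_with_logs.result, logs=input.logs + new_int_with_logs.logs)


# Nesting the v3 functions directly fails: square_v3 receives an IntWithLogs, not an int.
# mypy flags this call as an arg-type error; at runtime, `x * x` raises TypeError.
try:
    square_v3(add_one_v3(5))  # type: ignore[arg-type]
    nesting_failed = False
except TypeError:
    nesting_failed = True

# run_with_logs unwraps the int, applies the transform and merges the logs.
chained = run_with_logs(run_with_logs(IntWithLogs(result=5), add_one_v3), square_v3)
print(nesting_failed)  # True
print(chained.logs)  # ['Added 1 to 5 to get 6', 'Squared 6 to get 36']
```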
And this is pretty much it. A functioning monad. But what is the monad exactly?
To get a monad, you need the following components:
Base type: int
“Wrapper” type: IntWithLogs
A “Wrapper” function: wrap_with_logs (also known as “return”, “pure”, or “unit”, according to the video; don’t ask me why…)
A “Run” function: run_with_logs (also known as “bind”, “flatMap”, or “>>=”)
These together form a monad.
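Strictly speaking, having these components is not quite enough: the wrapper and run functions must also obey the three monad laws (left identity, right identity and associativity). The sketch below spot-checks them for the logging example on a few sample inputs; it is a sanity check, not a proof. A plain dataclass stands in for the pydantic model, and check_laws is a hypothetical helper.

```python
from dataclasses import dataclass, field
from typing import Callable


# Dependency-free stand-in for the pydantic IntWithLogs model (assumption for this sketch).
@dataclass
class IntWithLogs:
    result: int
    logs: list[str] = field(default_factory=list)


def wrap_with_logs(x: int) -> IntWithLogs:  # "return" / "pure" / "unit"
    return IntWithLogs(result=x, logs=[])


def run_with_logs(input: IntWithLogs, transform: Callable[[int], IntWithLogs]) -> IntWithLogs:  # "bind"
    new_int_with_logs = transform(input.result)
    return IntWithLogs(result=new_int_with_logs.result, logs=input.logs + new_int_with_logs.logs)


def square_v3(x: int) -> IntWithLogs:
    return IntWithLogs(result=x * x, logs=[f'Squared {x} to get {x * x}'])


def add_one_v3(x: int) -> IntWithLogs:
    return IntWithLogs(result=x + 1, logs=[f'Added 1 to {x} to get {x + 1}'])


def check_laws(x: int) -> bool:
    # Hypothetical helper: an already-wrapped value with some existing logs.
    m = IntWithLogs(result=x, logs=['start'])
    # Left identity: wrapping x and then running f is the same as calling f(x).
    left_identity = run_with_logs(wrap_with_logs(x), square_v3) == square_v3(x)
    # Right identity: running the wrapper function on m changes nothing.
    right_identity = run_with_logs(m, wrap_with_logs) == m
    # Associativity: how consecutive runs are grouped does not matter.
    associativity = run_with_logs(run_with_logs(m, add_one_v3), square_v3) == run_with_logs(
        m, lambda y: run_with_logs(add_one_v3(y), square_v3)
    )
    return left_identity and right_identity and associativity


print(all(check_laws(x) for x in range(-3, 4)))  # True
```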
Let’s put these together. After watching the video, I thought these could be part of the same class and form a single unit, so I refactored it like this:
from typing import Callable, Self

from pydantic import BaseModel


class IntWithLogs(BaseModel):
    result: int
    logs: list[str]

    @classmethod
    def wrap(cls, x: int) -> Self:
        return cls(
            result=x,
            logs=[],
        )

    def run(self, transform: Callable[[int], Self]) -> Self:
        new_int_with_logs = transform(self.result)
        # type(self) instead of IntWithLogs, so the return value really is Self (mypy rejects the latter)
        return type(self)(
            result=new_int_with_logs.result,
            logs=self.logs + new_int_with_logs.logs,
        )
Here, the changes are pretty straightforward: “wrap()” is a classmethod, and “run()” operates on the object’s own value but is otherwise identical to the previous version.
The interesting bit is the “Self” type that is apparently equivalent to:
Self = TypeVar("Self", bound="IntWithLogs")
See more on this here: https://docs.python.org/3/library/typing.html#typing.Self
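To make the equivalence concrete, here is a sketch of the same class written with the bound-TypeVar spelling instead of Self, which also works before Python 3.11. The practical payoff shows up with subclasses: because run() builds its result with type(self), a subclass stays a subclass through the chain. AuditedInt and audited_square are hypothetical names used only for this illustration, and a dataclass stands in for the pydantic model.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, TypeVar

# The pre-3.11 spelling of Self: a TypeVar bound to the class itself.
TIntWithLogs = TypeVar('TIntWithLogs', bound='IntWithLogs')


@dataclass
class IntWithLogs:
    result: int
    logs: list[str] = field(default_factory=list)

    @classmethod
    def wrap(cls: type[TIntWithLogs], x: int) -> TIntWithLogs:
        return cls(result=x, logs=[])

    def run(self: TIntWithLogs, transform: Callable[[int], TIntWithLogs]) -> TIntWithLogs:
        new_int_with_logs = transform(self.result)
        # type(self) rather than IntWithLogs keeps subclasses intact.
        return type(self)(result=new_int_with_logs.result, logs=self.logs + new_int_with_logs.logs)


@dataclass
class AuditedInt(IntWithLogs):
    """Hypothetical subclass, only here to show that the subclass type is preserved."""


def audited_square(x: int) -> AuditedInt:
    return AuditedInt(result=x * x, logs=[f'Squared {x} to get {x * x}'])


value = AuditedInt.wrap(3).run(audited_square)
print(type(value).__name__, value.result, value.logs)  # AuditedInt 9 ['Squared 3 to get 9']
```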
The tests look a bit different because the usage is more compact:
class TestFourth(TestCase):
    def test_self_run_wrap(self) -> None:
        expected = IntWithLogs(
            result=36,
            logs=[
                'Added 1 to 5 to get 6',
                'Squared 6 to get 36',
            ],
        )
        self.assertEqual(
            IntWithLogs.wrap(x=5).run(add_one_v3).run(square_v3),
            expected,
        )
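Because every step returns the same wrapper type, the chain of .run() calls does not have to be written out by hand: an arbitrary list of transforms can be folded with functools.reduce. A minimal sketch, assuming a hypothetical run_pipeline helper and a dataclass standing in for the pydantic class:

```python
from __future__ import annotations

from dataclasses import dataclass, field
from functools import reduce
from typing import Callable


# Dependency-free stand-in for the pydantic class above (assumption for this sketch).
@dataclass
class IntWithLogs:
    result: int
    logs: list[str] = field(default_factory=list)

    @classmethod
    def wrap(cls, x: int) -> IntWithLogs:
        return cls(result=x, logs=[])

    def run(self, transform: Callable[[int], IntWithLogs]) -> IntWithLogs:
        new_int_with_logs = transform(self.result)
        return type(self)(result=new_int_with_logs.result, logs=self.logs + new_int_with_logs.logs)


def square_v3(x: int) -> IntWithLogs:
    return IntWithLogs(result=x * x, logs=[f'Squared {x} to get {x * x}'])


def add_one_v3(x: int) -> IntWithLogs:
    return IntWithLogs(result=x + 1, logs=[f'Added 1 to {x} to get {x + 1}'])


def run_pipeline(x: int, transforms: list[Callable[[int], IntWithLogs]]) -> IntWithLogs:
    # Hypothetical helper: folds the transforms into wrap(x).run(t1).run(t2)..., left to right.
    return reduce(lambda acc, transform: acc.run(transform), transforms, IntWithLogs.wrap(x))


outcome = run_pipeline(5, [add_one_v3, square_v3, add_one_v3])
print(outcome.result)  # 37
print(outcome.logs)  # ['Added 1 to 5 to get 6', 'Squared 6 to get 36', 'Added 1 to 36 to get 37']
```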
Make it generic!
Now that we have gone to this much effort, we might as well examine how to turn the class into a generic one while keeping the typehints intact. This is not entirely straightforward, because it has to be a generic _pydantic_ class.
Luckily, pydantic has excellent documentation on this here: https://docs.pydantic.dev/latest/usage/models/#generic-models
from typing import Callable, Generic, Self, TypeVar

from pydantic import BaseModel

T = TypeVar('T')


class WithLogs(BaseModel, Generic[T]):
    result: T
    logs: list[str]

    @classmethod
    def wrap(cls, x: T) -> Self:
        return cls(
            result=x,
            logs=[],
        )

    def run(self, transform: Callable[[T], Self]) -> Self:
        new_result = transform(self.result)
        # type(self) keeps the concrete class (e.g. WithLogs[int]), matching the Self annotation
        return type(self)(
            result=new_result.result,
            logs=self.logs + new_result.logs,
        )
Well, this wasn’t too hard: inherit from Generic[T] as well and replace “int” with “T” everywhere. The tests use the new class and specify the inner type:
class TestFifth(TestCase):
    def test_self_run_wrap(self) -> None:
        expected = WithLogs[int](
            result=36,
            logs=[
                'Added 1 to 5 to get 6',
                'Squared 6 to get 36',
            ],
        )
        self.assertEqual(
            WithLogs[int].wrap(x=5).run(add_one_v3).run(square_v3),
            expected,
        )
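One caveat, as far as I can tell: add_one_v3 and square_v3 return IntWithLogs rather than WithLogs[int], so mypy should reject passing them to WithLogs[int].run, even though Python does not enforce the annotation at runtime. Below is a sketch with transforms that return WithLogs[int] instead. square_v5 and add_one_v5 are hypothetical names, and to keep the example dependency-free and runnable before Python 3.11, a generic dataclass stands in for the pydantic model, with WithLogs[T] replacing Self in the annotations.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Generic, TypeVar

T = TypeVar('T')


# Dependency-free stand-in for the generic pydantic model (assumption for this sketch).
@dataclass
class WithLogs(Generic[T]):
    result: T
    logs: list[str] = field(default_factory=list)

    @classmethod
    def wrap(cls, x: T) -> WithLogs[T]:
        return cls(result=x, logs=[])

    def run(self, transform: Callable[[T], WithLogs[T]]) -> WithLogs[T]:
        new_result = transform(self.result)
        return WithLogs(result=new_result.result, logs=self.logs + new_result.logs)


# Hypothetical transforms that return WithLogs[int], so they match run()'s signature.
def square_v5(x: int) -> WithLogs[int]:
    return WithLogs(result=x * x, logs=[f'Squared {x} to get {x * x}'])


def add_one_v5(x: int) -> WithLogs[int]:
    return WithLogs(result=x + 1, logs=[f'Added 1 to {x} to get {x + 1}'])


outcome = WithLogs[int].wrap(5).run(add_one_v5).run(square_v5)
print(outcome.result)  # 36
print(outcome.logs)  # ['Added 1 to 5 to get 6', 'Squared 6 to get 36']
```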
Summary
I think the easiest way to summarise this is with the signatures of the class:
class WrapperType(BaseModel, Generic[T]):
    inner: T
    something_else: list[str]

    # Gets the inner type, returns the wrapper type
    @classmethod
    def wrap(cls, x: T) -> Self: ...

    # Gets the transformer, returns the wrapper type
    def run(self, transform: Callable[[T], Self]) -> Self: ...
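Note that run() in this summary keeps the inner type fixed: the transform has to return the same wrapper type, Self. The general bind (Haskell’s >>=, with type m a -> (a -> m b) -> m b) lets the transform change the inner type as well. Here is a sketch of that variation, assuming a hypothetical bind method and describe helper, again on a generic dataclass rather than pydantic:

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Generic, TypeVar

T = TypeVar('T')
U = TypeVar('U')


# Dependency-free stand-in for the generic wrapper type (assumption for this sketch).
@dataclass
class WithLogs(Generic[T]):
    result: T
    logs: list[str] = field(default_factory=list)

    @classmethod
    def wrap(cls, x: T) -> WithLogs[T]:
        return cls(result=x, logs=[])

    def bind(self, transform: Callable[[T], WithLogs[U]]) -> WithLogs[U]:
        # Unlike run(), the transform may change the inner type from T to U.
        new_result = transform(self.result)
        return WithLogs(result=new_result.result, logs=self.logs + new_result.logs)


def square(x: int) -> WithLogs[int]:
    return WithLogs(result=x * x, logs=[f'Squared {x} to get {x * x}'])


def describe(x: int) -> WithLogs[str]:
    # Hypothetical transform from int to str.
    return WithLogs(result=f'The answer is {x}', logs=[f'Formatted {x} as text'])


text = WithLogs.wrap(6).bind(square).bind(describe)
print(text.result)  # The answer is 36
print(text.logs)  # ['Squared 6 to get 36', 'Formatted 36 as text']
```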
I hope you enjoyed this journey through monads and typehints. I think this is an excellent first step towards getting familiar with the concept. Feel free to comment if you have any feedback; otherwise, subscribe if you would like to get notified of similar content.