Numpy and mypy type checking
Recently I've been looking at the state of typing support in Python after a few years away from it. The other day I found out how to get type stubs for SQLAlchemy to work with mypy. Inspired by this win, I started looking at what type checking is supported by other libraries I use frequently.
Back when I was working on Persephone I particularly lamented the lack of type stubs for numpy and TensorFlow. We made some of our own type stubs for TensorFlow while working on the project because they helped us track down some bugs and ensure we didn't introduce new typing bugs. Unfortunately that was a lot of work, and our stubs could drift out of sync with the 3rd party library over time. It was time consuming but entirely worth it for what we were doing.
During this time a 3rd party project called numpy-stubs was created to provide type stubs for the numpy library. These made it much easier to use type checkers like mypy with the numpy library.
To my delight I found that type stubs are now part of the numpy project.
To use them you need a recent enough version of numpy: type stubs landed in version 1.20. You should no longer use numpy-stubs, since that project was merged into numpy.
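A quick way to confirm that an environment is new enough is to compare the installed version against 1.20. The following is a minimal sketch; the helper names `has_bundled_stubs` and `installed_numpy_has_stubs` are made up for this post, and the parsing assumes a version string that starts with numeric major and minor parts.

```python
from importlib import metadata


def has_bundled_stubs(version: str) -> bool:
    """Return True if this numpy version string is 1.20 or newer."""
    # Only the major and minor parts matter for this check.
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= (1, 20)


def installed_numpy_has_stubs() -> bool:
    """Check the numpy installed in the current environment, if any."""
    try:
        return has_bundled_stubs(metadata.version("numpy"))
    except metadata.PackageNotFoundError:
        # numpy is not installed at all.
        return False
```

If `installed_numpy_has_stubs()` returns `False`, upgrading numpy is the fix, not installing numpy-stubs.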
For mypy to check numpy code properly you need to enable numpy's mypy plugin. To do so, add a line to mypy.ini:
[mypy]
plugins = numpy.typing.mypy_plugin
Then, to annotate functions, you need to make sure you use the right types. Because of the way numpy types work, this is a bit involved.
Consider the following code:
import numpy as np

def sine_values(data):
    """Apply the sine function to data"""
    return np.sin(data)
The return type is fairly easy to annotate here:
def sine_values(data) -> np.ndarray:
    """Apply the sine function to data"""
    return np.sin(data)
The parameter type is harder to annotate because it depends on how the function is called. You can call the code like this:
>>> d1 = np.array([1,2,3,4])
>>> s1 = sine_values(d1)
>>> print(s1, d1)
[ 0.84147098 0.90929743 0.14112001 -0.7568025 ] [1 2 3 4]
but you can also call it like this:
>>> d2 = [1,2,3,4]
>>> s2 = sine_values(d2)
>>> print(d2, s2)
[1, 2, 3, 4] [ 0.84147098 0.90929743 0.14112001 -0.7568025 ]
To get around this "flexibility" in the accepted types you could use a Union type, but it's easier just to use the type provided by numpy:
import numpy as np
import numpy.typing as npt

def sine_values(data: npt.ArrayLike) -> np.ndarray:
    """Apply the sine function to data"""
    return np.sin(data)
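If you want the return type to say something about the dtype as well, one option is `npt.NDArray`. This is a sketch under two assumptions of mine, not something the function above requires: that numpy is at least 1.21 (where `NDArray` was added), and that converting the input to `float64` up front is acceptable for your use case.

```python
import numpy as np
import numpy.typing as npt


def sine_values(data: npt.ArrayLike) -> npt.NDArray[np.float64]:
    """Apply the sine function to any array-like input."""
    # Convert first so the dtype is known regardless of what was passed in.
    values = np.asarray(data, dtype=np.float64)
    return np.sin(values)


# Lists, tuples and arrays are all accepted by the ArrayLike annotation.
from_list = sine_values([1, 2, 3, 4])
from_tuple = sine_values((1, 2, 3, 4))
from_array = sine_values(np.array([1, 2, 3, 4]))
```

All three calls produce the same values at runtime; the annotation only affects what the type checker accepts.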
See the numpy typing docs for more.