Numpy and mypy type checking
Recently I've been looking at the state of typing support in Python after a few years away from it. The other day I found out how to get type stubs for SQLAlchemy working with mypy, and, inspired by this win, I started looking at what type checking support exists for other libraries I use frequently.
Back when I was working on Persephone I particularly lamented the lack of type stubs for numpy and TensorFlow. We wrote some of our own type stubs for TensorFlow during the project because they helped us track down some bugs and avoid introducing new typing bugs. Unfortunately this was a lot of work, and it was always possible for the 3rd party library to drift out of sync with our stubs over time. Still, for what we were doing, it was entirely worth it.
During this time a 3rd party project called numpy-stubs was created to provide type stubs for the numpy library. These made it much easier to use type checkers like mypy with numpy code.
To my delight I found that type stubs are now part of the numpy package itself.
To use them, make sure you have a recent enough version of numpy: type stubs landed in version 1.20. You should no longer be using numpy-stubs, since that project was merged into numpy.
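If you want a quick way to confirm that the numpy you have installed is new enough, something like the following works. This is a minimal sketch: the helper names are hypothetical, and it assumes numpy's usual `major.minor.patch` version strings (as found in `np.__version__`).

```python
import numpy as np


def version_tuple(version: str) -> tuple[int, int]:
    """Return the (major, minor) part of a version string such as '1.20.3'."""
    # Only the first two components are used, so suffixes like 'rc1' on the
    # patch number are ignored.
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def has_bundled_stubs(version: str) -> bool:
    """True if this numpy version ships its own type stubs (1.20 or later)."""
    return version_tuple(version) >= (1, 20)


if __name__ == "__main__":
    print(np.__version__, has_bundled_stubs(np.__version__))
```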
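To get the most accurate checking of numpy code you should also enable the mypy plugin that ships with numpy, which handles some of numpy's platform-specific types. To do so, add the following to your mypy configuration file (for example mypy.ini or setup.cfg):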
```ini
[mypy]
plugins = numpy.typing.mypy_plugin
```
Next, to annotate your functions you need to make sure you use the right types. Because of the way numpy types work, this is a bit involved (and somewhat annoying).
Consider the following code:
```python
import numpy as np

def sine_values(data):
    """Apply the sine function to data"""
    return np.sin(data)
```
The return type is fairly easy to annotate here:
```python
def sine_values(data) -> np.ndarray:
    """Apply the sine function to data"""
    return np.sin(data)
```
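If you want the return type to also say what kind of elements the array holds, `numpy.typing.NDArray` (added in numpy 1.21) takes the dtype as a parameter. The sketch below assumes the input is integer or float64 data, for which `np.sin` produces float64 results; it is one possible annotation, not the only correct one.

```python
import numpy as np
import numpy.typing as npt


def sine_values(data) -> npt.NDArray[np.float64]:
    """Apply the sine function to data"""
    # For integer or float64 array-like input, np.sin returns a float64 array.
    return np.sin(data)
```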
Annotating the parameter is harder, because numpy functions accept many different kinds of input. For example, you can call the code like this:
```python
>>> d1 = np.array([1,2,3,4])
>>> s1 = sine_values(d1)
>>> print(s1, d1)
[ 0.84147098 0.90929743 0.14112001 -0.7568025 ] [1 2 3 4]
```
but you can also call it like this:
```python
>>> d2 = [1,2,3,4]
>>> s2 = sine_values(d2)
>>> print(d2, s2)
[1, 2, 3, 4] [ 0.84147098 0.90929743 0.14112001 -0.7568025 ]
```
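Lists and arrays are not the only options. The sketch below runs the same unannotated function over a few other inputs numpy is happy to accept (tuples, nested lists and plain scalars), which shows why a single concrete type such as `np.ndarray` or `list[int]` is too narrow for the parameter.

```python
import numpy as np


def sine_values(data):
    """Apply the sine function to data"""
    return np.sin(data)


# Each of these inputs is "array-like" as far as numpy is concerned.
inputs = [
    np.array([1, 2, 3, 4]),  # an ndarray
    [1, 2, 3, 4],            # a list
    (1, 2, 3, 4),            # a tuple
    [[1, 2], [3, 4]],        # a nested list, treated as a 2x2 array
    1.0,                     # a plain Python float
]

for value in inputs:
    result = sine_values(value)
    # Show the input type, the result type and the result's shape.
    print(type(value).__name__, "->", type(result).__name__, np.shape(result))
```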
To get around this "flexibility" in the accepted types you could use a Union type, but it's easier to use the ArrayLike type that numpy provides:
```python
import numpy as np
import numpy.typing as npt

def sine_values(data: npt.ArrayLike) -> np.ndarray:
    """Apply the sine function to data"""
    return np.sin(data)
```
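For comparison, a hand-written Union along the lines suggested above might look like the sketch below. The alias `HandRolledInput` and the function `sine_values_union` are hypothetical names used only for illustration. At runtime both functions behave identically, but a type checker would reject calls like `sine_values_union(1.0)` or `sine_values_union([[1, 2], [3, 4]])`, because a float and a nested list are not covered by the Union, while `npt.ArrayLike` accepts both.

```python
from typing import Sequence, Union

import numpy as np
import numpy.typing as npt

# A hand-rolled Union that only covers flat arrays and flat sequences.
HandRolledInput = Union[np.ndarray, Sequence[float]]


def sine_values_union(data: HandRolledInput) -> np.ndarray:
    """Apply the sine function to data (hand-rolled Union version)"""
    return np.sin(data)


def sine_values(data: npt.ArrayLike) -> np.ndarray:
    """Apply the sine function to data (numpy's ArrayLike version)"""
    return np.sin(data)
```

Keeping a Union like this in sync with everything numpy accepts is exactly the kind of maintenance burden that `ArrayLike` saves you from.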
See the numpy typing docs for more.