Rewriting haskell-spsa to be monadic
A couple of weeks ago, I wrote the initial code for haskell-spsa, a port of golang-spsa. In my first post on it, I talked about how the Haskell implementation manages to be both simple and fast. In my post at make.rafflecopter.com, I talked about how porting a known package let me learn how to build a Haskell package the Haskell way.
Last Sunday, I decided to take a crack at simplifying the API for creating and running SPSA into something easier to use and understand. I have been reading Real World Haskell (a book I cannot praise enough), and I figured I could apply some of what I had learned to this task.
Let’s take a look at the old API for creating an SPSA object.
I wrote this code, and I can’t tell you exactly what it’s doing without digging through the source. There are magic numbers thrown around everywhere! Now, let’s see the new API:
Here we have what appears to be a much more imperative API. Even though it is more verbose, each line is more self-explanatory. We are constructing an SPSA configuration from the ground up: setting the loss function, stop criteria, and perturbation vector, and tuning the gain sequences (`a` and `c` are the primary tunable parameters). Better yet, there’s no IO monad buried deep down (but you’ll need to pass in a random seed). How is this implemented, you might ask?
By using the State monad, we can apply incremental updates to the implicit state of an `SPSA` data type as we construct it. That way, we can build an SPSA configuration in arbitrarily complicated ways and still keep a sane API. Now, let’s see how this has affected the internals.
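To make the idea concrete, here is a minimal sketch of a State-based builder. It is not the actual haskell-spsa code: the `Config` record, its fields, and the `setLoss`, `setMaxIters`, `setGains`, and `mkConfig` names are all made up for illustration, and it uses `Control.Monad.Trans.State.Strict` from the transformers package.

```haskell
import Control.Monad.Trans.State.Strict (State, execState, modify)

-- Hypothetical configuration record. The field names are illustrative
-- and are not haskell-spsa's actual definitions.
data Config = Config
  { lossFn   :: [Double] -> Double  -- function to minimise
  , maxIters :: Int                 -- stop after this many iterations
  , gainA    :: Double              -- scale of the step-size gain sequence
  , gainC    :: Double              -- scale of the perturbation gain sequence
  }

-- A builder is just a State action over the configuration.
type Builder = State Config

-- Starting point that the setters below overwrite.
defaultConfig :: Config
defaultConfig = Config { lossFn = const 0, maxIters = 0, gainA = 1, gainC = 1 }

-- Each setter updates one part of the implicit state.
setLoss :: ([Double] -> Double) -> Builder ()
setLoss f = modify (\cfg -> cfg { lossFn = f })

setMaxIters :: Int -> Builder ()
setMaxIters n = modify (\cfg -> cfg { maxIters = n })

setGains :: Double -> Double -> Builder ()
setGains a c = modify (\cfg -> cfg { gainA = a, gainC = c })

-- Run the builder against the defaults and keep only the final state.
mkConfig :: Builder () -> Config
mkConfig builder = execState builder defaultConfig

-- The do-block reads imperatively, but it is a pure value.
example :: Config
example = mkConfig $ do
  setLoss (sum . map (^ (2 :: Int)))
  setMaxIters 1000
  setGains 0.5 0.1
```

Because each setter is an ordinary `State` action, setters can be grouped into larger helpers or applied conditionally without changing how the configuration is finally built.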
We can see in that piece of code that the core optimization functions are easier to take in without a bunch of `where` clauses. `runSPSA'` is a simple tail-recursive function which runs a `singleIteration` until `checkStop t t'` returns true. All of this is done without explicitly passing around an `SPSA` object, thanks to the `State` monad.
Finally, we have the exported `StateSPSA` runner, which runs the SPSA instance with a given guess and returns the final guess.
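The shape of that loop and its runner can be sketched roughly as follows. This is only an illustration of the control flow under some loud assumptions: the `Opt` record and the `runLoop` and `minimise` names are hypothetical, the problem is one-dimensional, and the update step is a plain finite-difference descent step standing in for a real SPSA iteration (no random perturbation is involved). Here `checkStop` is also a `State` action so it can read the tolerance and iteration budget from the implicit state.

```haskell
import Control.Monad.Trans.State.Strict (State, evalState, gets, modify')

-- Hypothetical optimizer state. The real SPSA type carries more
-- (gain sequences, perturbation vectors, and so on).
data Opt = Opt
  { loss      :: Double -> Double  -- function to minimise
  , stepSize  :: Double            -- fixed step size (stand-in for a gain sequence)
  , tolerance :: Double            -- stop when successive guesses are this close
  , itersLeft :: Int               -- remaining iteration budget
  }

-- Stand-in for one SPSA update: a central finite-difference descent step.
-- It reads its parameters from the implicit state and spends one iteration.
singleIteration :: Double -> State Opt Double
singleIteration t = do
  f <- gets loss
  a <- gets stepSize
  let h    = 1e-4
      grad = (f (t + h) - f (t - h)) / (2 * h)
  modify' (\o -> o { itersLeft = itersLeft o - 1 })
  return (t - a * grad)

-- Stop when the budget is spent or the guess has stopped moving.
checkStop :: Double -> Double -> State Opt Bool
checkStop t t' = do
  tol <- gets tolerance
  n   <- gets itersLeft
  return (n <= 0 || abs (t' - t) < tol)

-- Tail-recursive loop: iterate until checkStop says we are done.
runLoop :: Double -> State Opt Double
runLoop t = do
  t'   <- singleIteration t
  done <- checkStop t t'
  if done then return t' else runLoop t'

-- Exported-style runner: takes an initial guess, returns the final guess.
minimise :: Opt -> Double -> Double
minimise opt guess = evalState (runLoop guess) opt
```

Note that neither `singleIteration` nor `runLoop` takes the optimizer record as an argument. Only the current guess is threaded explicitly; everything else comes from the state.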
For more details, check out the GitHub repository and the examples, or download it from Hackage and start hacking!