Sunday, March 14, 2010

Animated evolution of RNN state-space

The above animation depicts the state space of a discrete-time recurrent neural network (RNN) undergoing neuroevolution. Each new frame represents the discovery of an improvement over the previous best RNN.

The goal is to classify the parity of a binary string, in this case of length 10: the label is 1 if the string contains an odd number of 1s and 0 otherwise. The string is fed to the RNN sequentially, one bit per time step.

A few training examples:

0100110101 -> 1
1110001011 -> 0
0010000010 -> 0
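A minimal sketch of how such labelled strings could be generated, assuming uniformly random bits (the post does not say how its examples were sampled). `parity_label` and `make_dataset` are hypothetical names used only for illustration.

```python
import random

def parity_label(bits: str) -> int:
    """Return 1 if the string contains an odd number of '1's, else 0."""
    return bits.count("1") % 2

def make_dataset(n_examples: int, length: int = 10, seed: int = 0):
    """Generate random fixed-length binary strings paired with their parity labels."""
    rng = random.Random(seed)
    data = []
    for _ in range(n_examples):
        bits = "".join(rng.choice("01") for _ in range(length))
        data.append((bits, parity_label(bits)))
    return data

# The three examples from the post.
for s in ["0100110101", "1110001011", "0010000010"]:
    print(s, "->", parity_label(s))
```

Running it prints the same labels as the examples above (1, 0, 0).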

The RNN I used for this particular experiment has a hidden layer of two neurons, which makes it easy to plot the state space in two dimensions: the x-axis is the activation of the first neuron, and the y-axis is the activation of the second. Each point within a given frame is the 'state' of the RNN at the very end of processing one string, so each frame contains one point per training example. The first few frames appear to have fewer points, but that is only because many points overlap. The points are colored red and blue according to whether the target parity is even or odd; the colors do not reflect the RNN's actual output.
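A minimal sketch of how one plotted point could be computed, assuming tanh hidden units, the current bit as a single scalar input, and a zero initial state; the post does not specify the activation function or initialization. The weights below are random placeholders, not evolved values, and `final_state` is a hypothetical name.

```python
import math
import random

def final_state(bits, w_in, w_rec, bias):
    """Run a 2-neuron tanh RNN over a bit string, one bit per time step,
    and return the hidden state (x, y) after the last bit."""
    h = [0.0, 0.0]  # assumed zero initial state
    for b in bits:
        u = float(b)  # current bit as 0.0 or 1.0
        # Both neurons are updated from the previous state h.
        h = [
            math.tanh(w_in[i] * u + w_rec[i][0] * h[0] + w_rec[i][1] * h[1] + bias[i])
            for i in range(2)
        ]
    return tuple(h)

# Placeholder weights; an evolved RNN would supply its own.
rng = random.Random(1)
w_in = [rng.uniform(-2, 2) for _ in range(2)]
w_rec = [[rng.uniform(-2, 2) for _ in range(2)] for _ in range(2)]
bias = [rng.uniform(-1, 1) for _ in range(2)]

strings = ["0100110101", "1110001011", "0010000010"]
points = [final_state(s, w_in, w_rec, bias) for s in strings]
for s, (x, y) in zip(strings, points):
    print(f"{s}: x={x:+.3f}, y={y:+.3f}")
```

Each `(x, y)` pair is one point in a frame; plotting all training strings for a given set of weights gives one frame.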

As can be seen, over the course of evolution the state space of the RNN rearranges itself so that the red and blue points become progressively more linearly separable. By the last frame it has nearly achieved that goal (as shown below), and classification accuracy on unseen examples is roughly 99%.
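The post does not describe its evolutionary algorithm, so the following is only a sketch of one simple possibility: a (1+1) hill climber with Gaussian mutation. Assumed choices include tanh hidden units, a linear threshold readout on the two hidden activations (which can only succeed if the final states are linearly separable), fitness equal to training accuracy, and a separate set of unseen strings for the held-out check. As in the animation, a child is kept only if it strictly improves on the previous best, so each accepted child would correspond to one frame. All names and hyperparameters are hypothetical, and no particular accuracy is implied for this short run.

```python
import math
import random

def run_rnn(genome, bits):
    """Final 2-D hidden state of a 2-neuron tanh RNN.
    Assumed genome layout: w_in[0:2], w_rec[2:6] (row-major 2x2), bias[6:8], readout[8:11]."""
    w_in, w_rec, bias = genome[0:2], genome[2:6], genome[6:8]
    h0, h1 = 0.0, 0.0
    for b in bits:
        u = float(b)
        h0, h1 = (
            math.tanh(w_in[0] * u + w_rec[0] * h0 + w_rec[1] * h1 + bias[0]),
            math.tanh(w_in[1] * u + w_rec[2] * h0 + w_rec[3] * h1 + bias[1]),
        )
    return h0, h1

def accuracy(genome, data):
    """Fraction of strings whose parity a linear readout of the final state predicts correctly."""
    w0, w1, c = genome[8:11]
    correct = 0
    for bits, label in data:
        x, y = run_rnn(genome, bits)
        pred = 1 if w0 * x + w1 * y + c > 0 else 0
        correct += pred == label
    return correct / len(data)

def make_data(n, length, rng):
    strings = ["".join(rng.choice("01") for _ in range(length)) for _ in range(n)]
    return [(s, s.count("1") % 2) for s in strings]

def evolve(train, generations=500, sigma=0.3, seed=0):
    """(1+1) hill climber: mutate the best genome and keep the child only if it
    strictly improves training accuracy. Each accepted child is one 'frame'."""
    rng = random.Random(seed)
    best = [rng.gauss(0, 1) for _ in range(11)]
    best_fit = accuracy(best, train)
    frames = [(0, best_fit)]
    for gen in range(1, generations + 1):
        child = [g + rng.gauss(0, sigma) for g in best]
        fit = accuracy(child, train)
        if fit > best_fit:
            best, best_fit = child, fit
            frames.append((gen, fit))
    return best, frames

rng = random.Random(42)
train = make_data(100, 10, rng)
test = make_data(100, 10, rng)
best, frames = evolve(train)
for gen, fit in frames:
    print(f"generation {gen}: train accuracy {fit:.2f}")
print(f"held-out accuracy: {accuracy(best, test):.2f}")
```

Keeping only strict improvements is what makes each frame a new best; a population-based method would produce frames in the same way if it logged only changes in the best individual.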

For those who are interested, I created these visualizations using a combination of Processing and GIMP.


todicus said...

This is really cool! Do you have any idea why the shapes carved by the X & Y values have such fluid dynamic / fractal visual qualities?

Thomas said...

Good question! I think this can be broken down into two sub-questions:

(1) Why is the RNN capable of producing such fractal state-spaces?

(2) Why does the evolutionary process "choose" to make use of these fractal dynamics?

In answer to the first question, I think this is simply because the output is the result of a nonlinear recurrent process, much as the Mandelbrot set can be generated from a very simple recurrence relation.
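For reference, the Mandelbrot recurrence is z -> z*z + c starting from z = 0, and a point c is in the set if the iterates stay bounded. A minimal escape-time sketch; the iteration cap of 100 and the coarse text rendering are arbitrary choices for illustration.

```python
MAX_ITER = 100  # arbitrary cap; points that survive this long are treated as inside

def escape_time(c: complex, max_iter: int = MAX_ITER) -> int:
    """Iterate z -> z*z + c from z = 0; return the step at which |z| first
    exceeds 2, or max_iter if it never does."""
    z = 0j
    for n in range(max_iter):
        if abs(z) > 2:
            return n
        z = z * z + c
    return max_iter

# Coarse text rendering of the region [-2, 1] x [-1, 1]: '#' marks points that never escaped.
for y in [1.0 - 0.1 * k for k in range(21)]:
    row = ""
    for x in [-2.0 + 0.05 * k for k in range(61)]:
        row += "#" if escape_time(complex(x, y)) == MAX_ITER else "."
    print(row)
```

The update rule is a one-liner, yet the boundary between bounded and escaping points is fractal; the RNN's state update is similarly just a repeated nonlinear map.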

The second question is (to me) much harder to answer definitively. I suspect that it is not so much a matter of evolution "preferring" complex/fractal dynamics, but rather that these sorts of dynamics are overwhelmingly common in the state spaces of RNNs, and are thus almost unavoidable.

You might also be interested in this.