Sunday, January 18, 2009

Thoughts on Understanding Neural Networks

Lately, I've been thinking quite a bit about neural networks. In particular, I've been wondering whether it is actually possible to understand them. As a note, this posting assumes that the reader has some understanding of neural networks. Of course, we at Data Miners heartily recommend our book Data Mining Techniques for Marketing, Sales, and Customer Relationship Management as an introduction to neural networks (as well as a plethora of other data mining algorithms).

Let me start with a picture of a neural network. The following is a simple network that takes three inputs and has two nodes in the hidden layer:

Note that the structure of the network explains what is really happening. The "input layer" (the first layer, connected to the inputs) standardizes the inputs. The "output layer" (connected to the output) is doing a regression or logistic regression, depending on whether the target is numeric or binary. The hidden layer nodes are doing a mathematical operation as well. This could be the logistic function; more typically, though, it is the hyperbolic tangent. All of the lines in the diagram have weights on them. Setting these weights -- plus a few others not shown -- is the process of training the neural network.
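
To make this concrete, here is a minimal sketch in Python of the computation such a network performs. This is not SAS code, and the weights below are made up purely for illustration:

    import numpy as np

    def forward(x, W_hidden, b_hidden, w_out, b_out):
        """A 3-input, 2-hidden-node MLP: tanh hidden units, logistic output."""
        h = np.tanh(W_hidden @ x + b_hidden)                # hidden outputs H1, H2
        return 1.0 / (1.0 + np.exp(-(w_out @ h + b_out)))   # logistic output

    # Illustrative weights only; training is the process of setting these.
    W_hidden = np.array([[ 0.5, -1.2,  0.8],
                         [ 1.1,  0.3, -0.7]])
    b_hidden = np.array([ 0.1, -0.4])
    w_out    = np.array([ 2.0,  0.5])
    b_out    = -0.3

    x = np.array([0.2, 0.7, 0.4])    # three standardized inputs
    print(forward(x, W_hidden, b_hidden, w_out, b_out))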

This topology is specifically how SAS Enterprise Miner implements the network. Other tools have similar capabilities. Here, I am using SAS EM for three reasons. First, because we teach a class using this tool, I have pre-built neural network diagrams. Second, the neural network node allows me to score the hidden units. And third, the graphics provide a data-colored scatter plot, which I use to show what's happening.

There are several ways to understand this neural network. The most basic way is "it's a black box and we don't need to understand it." In many respects, this is the standard data mining viewpoint. Neural networks often work well. However, if you want a technique that lets you understand what it is doing, then choose another technique, such as regression, decision trees, or nearest neighbor.

A related viewpoint is to write down the equation for what the network is doing and then point out that this equation *is* the network. The problem is not that the network cannot explain what it is doing. The problem is that we human beings cannot understand what it is saying.
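
For the network above (with a binary target), that equation is: output = logistic(b0 + a1*tanh(c1 + w11*x1 + w12*x2 + w13*x3) + a2*tanh(c2 + w21*x1 + w22*x2 + w23*x3)). Every symbol is precise, yet the formula says almost nothing a person can act on -- and real networks are far bigger than this one.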

I am going to propose two other ways of looking at the network. One is geometric: the inputs are projected onto the outputs of the hidden layer, and the results of this projection are then combined to form the output. The other method is, for lack of a better term, "clustering": the hidden nodes identify patterns in the original data, and one hidden node usually dominates the output within each cluster.

Let me start with the geometric interpretation. For the network above, there are three dimensions of inputs and two hidden nodes. So, three dimensions are projected down to two dimensions.

I do need to emphasize that these projections are not linear projections; that is, they are not described by simple matrices. These are non-linear projections. In particular, a given dimension could be stretched non-uniformly, which further complicates the situation.
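
A quick way to see the non-uniform stretching is to look at the hyperbolic tangent itself:

    import numpy as np

    # tanh compresses large values far more than small ones, so each
    # dimension of the projection is stretched non-uniformly.
    for z in (0.1, 0.5, 1.0, 2.0, 4.0):
        print(z, np.tanh(z), np.tanh(z) / z)   # the ratio shrinks as z grows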

I chose two nodes in the hidden layer on purpose, simply because two dimensions are pretty easy to visualize. Then I tried it on a small neural network, using Enterprise Miner. The next couple of pictures are scatter plots made with EM. It has the nice feature that I can color the points based on data -- a feature sadly lacking from Excel.

The following scatter plot shows the original data points (about 2,700 of them). The positions are determined by the outputs of the two hidden units. The colors show the output of the network itself (blue being close to 0 and red being close to 1). The network is predicting a value of 0 or 1 based on a balanced training set and three inputs.
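
Outside of Enterprise Miner, this kind of plot is easy to produce once the hidden units can be scored. A sketch in Python, with made-up weights and random data standing in for the real network and inputs:

    import numpy as np
    import matplotlib.pyplot as plt

    rng = np.random.default_rng(0)
    X = rng.uniform(0, 1, size=(2700, 3))             # stand-in for the real inputs

    # Made-up weights; in practice, score the trained network instead.
    W_hidden = np.array([[ 0.5, -1.2,  0.8],
                         [ 1.1,  0.3, -0.7]])
    b_hidden = np.array([ 0.1, -0.4])
    w_out    = np.array([ 2.0,  0.5])
    b_out    = -0.3

    H = np.tanh(X @ W_hidden.T + b_hidden)            # hidden-unit scores H1, H2
    y = 1.0 / (1.0 + np.exp(-(H @ w_out + b_out)))    # network output

    plt.scatter(H[:, 0], H[:, 1], c=y, cmap='coolwarm', s=5)
    plt.xlabel('H1'); plt.ylabel('H2')
    plt.colorbar(label='network output')
    plt.show()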

Hmm, the overall output is determined almost entirely by the H1 output rather than the H2 output. We see this because the color changes primarily as we move horizontally across the scatter plot and not vertically. This is interesting. It means that H2 is contributing little to the network prediction. Under these particular circumstances, we can explain the output of the neural network by explaining what is happening at H1. And what is happening at H1 is a lot like a logistic regression, where we can determine the weights of different variables going in.

Note that this is an approximation, because H2 does make some contribution. But it is a close approximation, because for almost all input data points, H1 is the dominant node.

This pattern is a consequence of the distribution of the input data. Note that H2 is always negative and close to -1, whereas H1 varies from -1 to 1 (as we would expect, given the transfer function). This is because the inputs are always positive and in a particular range; they do not produce the full range of values for each hidden node. This fact, in turn, provides a clue to what the neural network is doing. Also, this is close to a degenerate case, because one hidden unit is almost always ignored. It does illustrate that looking at the outputs of the hidden layer is useful.

This suggests another approach. Imagine the space of H1 and H2 values, and further imagine that any combination of them might exist (remember that, because of the transfer function, the values are actually limited to the range -1 to 1). Within this space, which node dominates the calculation of the output of the network?

To answer this question, I had to come up with some reasonable way to compare the following values:
  • Network output: exp(bias + a1*H1 + a2*H2)
  • H1 only: exp(bias + a1*H1)
  • H2 only: exp(bias + a2*H2)
Let me give an example with numbers. For the network above, we have the following when H1 and H2 are both -1:
  • Network output: 0.9994
  • H1 only output: 0.9926
  • H2 only output: 0.9749
To calculate the contribution of H1, I take the ratio of its squared difference to the sum of the squared differences, as in the following example for H1:
  • H1 contribution: (0.9994 - 0.9926)^2 / ((0.9994 - 0.9926)^2 + (0.9994 - 0.9749)^2)
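
In code, the metric looks like this, using the three output values from the example above:

    # Output values from the example above, at H1 = H2 = -1.
    full, h1_only, h2_only = 0.9994, 0.9926, 0.9749

    d1 = (full - h1_only) ** 2   # how much dropping H2 changes the output
    d2 = (full - h2_only) ** 2   # how much dropping H1 changes the output

    print(d1 / (d1 + d2))        # about 0.07: H1 alone nearly reproduces the
                                 # full output, so H1 dominates at this point
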
The following scatter plot shows the regions where H1 dominates the overall prediction of the network using this metric (red: H1 is dominant; blue: H2 is dominant):


There are four regions in this scatter plot, defined essentially by the intersection of two lines. In fact, each hidden node is going to add another line on this chart, generating more regions. Within each region, one node is going to dominate. The boundaries are fuzzy. Sometimes this makes no difference, because the output on either side is the same; sometimes it does make a difference.
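
Here is a sketch of how such a region map can be generated. The output-layer weights (bias, a1, a2) below are illustrative stand-ins, not the trained values, and the output formula is taken as written above:

    import numpy as np
    import matplotlib.pyplot as plt

    bias, a1, a2 = 0.5, 2.0, 0.8                # illustrative weights only

    h1, h2 = np.meshgrid(np.linspace(-1, 1, 200), np.linspace(-1, 1, 200))
    full = np.exp(bias + a1 * h1 + a2 * h2)
    d1 = (full - np.exp(bias + a1 * h1)) ** 2   # error from dropping H2
    d2 = (full - np.exp(bias + a2 * h2)) ** 2   # error from dropping H1

    # H1 dominates wherever dropping H2 changes the output least.
    plt.pcolormesh(h1, h2, (d1 < d2).astype(float), cmap='coolwarm')
    plt.xlabel('H1'); plt.ylabel('H2')
    plt.title('red: H1 dominant; blue: H2 dominant')
    plt.show()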

Note that this scatter plot assumes that the inputs can generate all combinations of values from the hidden units. However, in practice, this is not true, as shown in the previous scatter plot, which essentially covers only the lowest eighth of this one.

With the contribution metric, we can then say that for different regions in the hidden unit space, different hidden units dominate the output. This is essentially saying that in different areas, we only need one hidden unit to determine the outcome of the network. Within each region, then, we can identify the variables used by the hidden units and say that they are determining the outcome of the network.

This idea leads to a way to start to understand standard multilayer perceptron neural networks, at least in the space of the hidden units. We can identify the regions where particular hidden units dominate the output of the network. Within each region, we can identify which variables dominate the output of that hidden unit. Perhaps this alone explains what is happening in the network, when the input ranges limit the outputs to only one region.

More likely, we have to return to the original inputs to determine which hidden unit dominates for a given combination of inputs. I've only just started thinking about this idea, so perhaps I'll follow up in a later post.

--gordon

8 comments:

  1. nice!

    I tried to do something like this a few years ago. I like yours more. The best idea I could come up with was to use colour and transparency to represent the positive or negative effects and strengths of the NN weights and hidden nodes.

    I went to the effort of creating a vb.net application that loads PMML into data grids and then from the data grids builds the graphic. You might be able to use the data grids for your own graphics.

    It should work with a neural net saved as PMML from SAS EM. I tested it a little bit with Clementine and SAS PMML, but only with simple single-hidden-layer NNs.

    See:
    http://www.kdkeys.net/forums/thread/6495.aspx

    You can simply drag and drop the NN PMML onto the .exe and it will display the graph. I supplied the code and everything. Use as you wish. Completely free!

    I'm not sure it's much use, but I did it for fun to learn a bit of VB.NET (I'm certainly no programmer :)

    Cheers

    Tim

    ReplyDelete
  2. I added a few pictures to my blog in case you don't want to run the vb.net executable etc.

    http://timmanns.blogspot.com/2009/01/re-thoughts-on-understanding-neural.html

    cheers

    Tim

    ReplyDelete
  3. Gordon,
    I spent 4 years of my life trying to understand how neural networks did what they did during my doctorate. It was only after I started drawing plots of the outputs of the hidden neurons that the penny finally dropped and everything became crystal clear.

    What I would recommend is to start off with simple made-up functions such as y=x^2. It quickly becomes apparent why 2 sigmoidal neurons are required. Then try y=log(x), and then a function such as y=x^2 + log(x).
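
    For instance, a bare-bones version of the y=x^2 experiment looks something like this (a rough sketch in Python, not the VBA behind my Excel app):

      import numpy as np

      rng = np.random.default_rng(0)
      x = np.linspace(-1, 1, 200).reshape(-1, 1)
      y = x ** 2

      # 1 input -> 2 tanh hidden neurons -> 1 linear output
      W1 = rng.normal(size=(1, 2)); b1 = np.zeros(2)
      W2 = rng.normal(size=(2, 1)); b2 = np.zeros(1)

      for _ in range(20000):                  # plain gradient descent
          h = np.tanh(x @ W1 + b1)            # hidden neuron outputs
          err = (h @ W2 + b2) - y
          dh = (err @ W2.T) * (1 - h ** 2)    # backpropagate the error
          W2 -= 0.05 * h.T @ err / len(x); b2 -= 0.05 * err.mean(0)
          W1 -= 0.05 * x.T @ dh / len(x); b1 -= 0.05 * dh.mean(0)

      # Plot h[:, 0] * W2[0, 0] and h[:, 1] * W2[1, 0] against x to see
      # how the two sigmoids combine to approximate the parabola.
      print(np.abs(err).mean())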

    The ability to 'prune' a network to force particular inputs to be processed by selected hidden neurons is a must; you can then essentially decompose the model into y = f(x1,x2) + f(x3) + f(x4), etc.
    This gives you a great ability to extract underlying trends from data. For example, my doctorate was on electric load prediction, and this method enables the load to be decomposed into f(temperature) + f(time of day) + f(day of week), etc.
    More recently I extracted the premium of LPG vehicles over unleaded petrol vehicles over time for cars sold at auction. The results did reflect the difference in the two fuel costs (and government subsidies to convert your car) over time.

    I wrote a little application in Excel that will allow you to view the outputs of the hidden neurons as the network is being trained.

    http://www.philbrierley.com/code/vba.html

    Go to the Train tab and select the 'decomposition' graph. It can be informative to watch. The default data is something like y=x^2, but you can paste in any data you want.

    I also wrote an application called Tiberius that lets you do a similar thing but also lets you manually 'prune' the network to force the form of the model.

    www.tiberius.biz

    You can see some oldish screenshots of the decomposition at
    www.philbrierley.com/tiberius/slideshow.html
    slide 8 of the y=x^2 demo will give you the idea.

    Hope my work will help you on your quest to understand neural networks. Would be more than willing to link up on Skype to talk you through what can be achieved.

    Regards

    Phil Brierley.

    ReplyDelete
  4. I would also point out the Neural Networks chapter in Tom Mitchell's "Machine Learning" book. He goes through some cases of ANNs applied to image recognition, and these are quite telling as to what hidden nodes *could* learn.
    Max Khesin.
    (Regards!)

    ReplyDelete
  5. I remember a publication by Le Cun in the early 90s, I think, that described how a neural network doing OCR behaved -- if I remember correctly, some hidden-layer neurons were effectively horizontal line detectors, others vertical line detectors, others edge detectors (this one generated a lot of excitement because the human eye does edge detection, among many other things!).

    However, I think there is value in understanding overall trends in a neural network too (overall sensitivity), much like one sees in many neural net applications such as Clementine or Neuralware Predict. These are akin to examining coefficients in a regression model in that they give overall average trends but not much else.

    ReplyDelete
  6. Hi, one way I find useful to understand neural networks is to look at various applications from a time series regression perspective.

    I imagine a linear regression as just being the best fit to a noisy line. If you think of any other scatterplot that might be uniquely represented by a non-linear relationship, a neural network can usually find it.

    I wrote up a brief tutorial on learning financial time series that you may find of interest at:

    http://intelligenttradingtech.blogspot.com/

    ReplyDelete
  7. Here is a link to the best visualization of a neural network that I've ever seen.
    http://webgl-ann-experiment.appspot.com/

    ReplyDelete
  8. This is really cool! I am just a student right now and have some experience with neural networks, but I have always thought of them just like you described at the beginning of your article: as a black box, or like a mathematical formula. More so like a collection of mathematical formulas than a black box, but anyways…
    I know this post is pretty old, but I got so excited about it that I'm just going to post anyway, even though it might have no scientific value for you… I honestly have never even thought of any way of visualizing a neural network, or of purely analyzing which hidden unit contributes the most to the output. I am not really familiar with all of the current advancements in neural network metalearning, but it seems like what you described may help researchers and data scientists explain their findings more clearly. Not only that, the researchers themselves may better understand what's going on with their learning model if they analyze it by what you called "clustering" (for lack of a better word :) ).
    Thank you so much for this post!

    ReplyDelete
