I have a neural network written in Java which uses a sigmoid transfer function defined as follows:
private static double sigmoid(double x)
{
    return 1 / (1 + Math.exp(-x));
}
This is called many times during training and when computing outputs with the network. Is there any way of speeding it up? It's not that it's slow; it's just used so much that a small optimisation here would be a big overall gain.
For neural networks, you don't need the exact value of the sigmoid function. So you can precalculate, say, 100 values and reuse the one closest to your input, or, even better (as a commenter suggested), interpolate between the neighbouring values.
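A minimal Java sketch of the interpolated lookup, assuming a table over [-10, 10] that is clamped outside that range. It uses 1000 intervals rather than 100; that finer grid is an arbitrary choice that keeps the interpolation error small. The class and constant names are hypothetical. Whether this actually beats `Math.exp` depends on your JVM and hardware, so benchmark it before switching.

```java
// Sketch: sigmoid via a precomputed table with linear interpolation.
// Assumptions (not from the original answer): the table covers [-10, 10],
// uses 1000 intervals, and inputs outside that range are clamped to the end values.
public class SigmoidTable {
    private static final double MIN_X = -10.0;
    private static final double MAX_X = 10.0;
    private static final int INTERVALS = 1000;
    private static final double STEP = (MAX_X - MIN_X) / INTERVALS;
    private static final double[] TABLE = new double[INTERVALS + 1];

    static {
        // Precompute the exact sigmoid at each grid point once.
        for (int i = 0; i <= INTERVALS; i++) {
            double x = MIN_X + i * STEP;
            TABLE[i] = 1.0 / (1.0 + Math.exp(-x));
        }
    }

    /** Approximates sigmoid(x) by interpolating between the two nearest table entries. */
    public static double sigmoid(double x) {
        if (x <= MIN_X) return TABLE[0];
        if (x >= MAX_X) return TABLE[INTERVALS];
        double pos = (x - MIN_X) / STEP;          // fractional table index
        int i = (int) pos;                        // left neighbour
        if (i >= INTERVALS) i = INTERVALS - 1;    // guard against rounding at the right edge
        double frac = pos - i;                    // weight of the right neighbour
        return TABLE[i] + frac * (TABLE[i + 1] - TABLE[i]);
    }

    public static void main(String[] args) {
        for (double x : new double[] {-3.7, 0.0, 2.5}) {
            double exact = 1.0 / (1.0 + Math.exp(-x));
            System.out.printf("x=%5.2f  table=%.6f  exact=%.6f%n", x, sigmoid(x), exact);
        }
    }
}
```

The table costs about 8 KB of doubles; a coarser grid saves memory at the price of accuracy.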
How to do this is described in this article (link stolen from s-lott's answer).
This is the sigmoid function:
As you can see, only values in the range -10 < x < 10 are interesting at all; outside it the function is within about 5e-5 of 0 or 1. And, as another comment pointed out, the function is symmetric: sigmoid(-x) = 1 - sigmoid(x). So you only have to store half of the values.
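A sketch of the symmetry trick, storing only x in [0, 10] and mirroring negative inputs with sigmoid(-x) = 1 - sigmoid(x). It uses the simpler nearest-entry lookup. The spacing of 0.1 (101 entries) and the class name are hypothetical choices, and inputs beyond |x| = 10 simply reuse the table's end value.

```java
// Sketch: exploit sigmoid(-x) = 1 - sigmoid(x) so the table only covers [0, 10].
// Assumptions (hypothetical): spacing 0.1, nearest-entry lookup,
// and inputs beyond |x| = 10 reuse the value stored for x = 10.
public class HalfSigmoidTable {
    private static final double MAX_X = 10.0;
    private static final int STEPS_PER_UNIT = 10;                        // table spacing 0.1
    private static final int SIZE = (int) (MAX_X * STEPS_PER_UNIT) + 1;  // 101 entries for [0, 10]
    private static final double[] TABLE = new double[SIZE];

    static {
        for (int i = 0; i < SIZE; i++) {
            double x = (double) i / STEPS_PER_UNIT;
            TABLE[i] = 1.0 / (1.0 + Math.exp(-x));
        }
    }

    /** Nearest-entry lookup for x >= 0; inputs at or beyond MAX_X reuse the last entry. */
    private static double lookupNonNegative(double x) {
        if (x >= MAX_X) return TABLE[SIZE - 1];
        int i = (int) Math.round(x * STEPS_PER_UNIT);  // index of the closest grid point
        return TABLE[i];
    }

    public static double sigmoid(double x) {
        // Mirror negative inputs onto the stored half of the curve.
        return x >= 0 ? lookupNonNegative(x) : 1.0 - lookupNonNegative(-x);
    }

    public static void main(String[] args) {
        System.out.println(sigmoid(1.0) + " + " + sigmoid(-1.0) + " = " + (sigmoid(1.0) + sigmoid(-1.0)));
    }
}
```

Nearest-entry lookup at this spacing can be off by roughly 0.01 near x = 0, where the curve is steepest; interpolation or a finer grid reduces that.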
Edit: I'm sorry that I showed the wrong graph here. I've corrected it.
If you have a lot of nodes where the value of x is outside the -10..+10 box, you can skip calculating those values altogether, e.g., like so:
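private static double sigmoid(double x)
{
    double y;
    if (x < -10)
        y = 0;    // saturated: the true value is below about 4.54e-5
    else if (x > 10)
        y = 1;    // saturated: the true value is above about 1 - 4.54e-5
    else
        y = 1 / (1 + Math.exp(-x));
    return y;
}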
Of course, this incurs the overhead of the conditional checks for EVERY calculation, so it's only worthwhile if you have lots of saturated nodes.
Another thing worth mentioning: if you are using backpropagation and have to deal with the slope of the function, it's better to compute it in pieces rather than 'as written'.
I can't recall the slope at the moment, but here's what I'm talking about, using a bipolar sigmoid as an example. Rather than computing it this way
double y = (1 - Math.exp(-x)) / (1 + Math.exp(-x));
which calls exp() twice, you can cache the costly calculation in a temporary variable, like so
double temp = Math.exp(-x);
double y = (1 - temp) / (1 + temp);
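A self-contained version of the bipolar example (the class name is hypothetical). It also compares the result with `Math.tanh(x / 2)`, which is mathematically the same function, since (1 - e^-x) / (1 + e^-x) = tanh(x/2).

```java
// Sketch of the caching idea for the bipolar sigmoid f(x) = (1 - e^-x) / (1 + e^-x).
public class BipolarSigmoid {
    /** Naive form: evaluates Math.exp(-x) twice. */
    public static double naive(double x) {
        return (1 - Math.exp(-x)) / (1 + Math.exp(-x));
    }

    /** Cached form: evaluates Math.exp(-x) once and reuses it. */
    public static double cached(double x) {
        double temp = Math.exp(-x);
        return (1 - temp) / (1 + temp);
    }

    public static void main(String[] args) {
        double x = 0.8;
        // Both forms agree, and both equal tanh(x / 2).
        System.out.println(naive(x) + " " + cached(x) + " " + Math.tanh(x / 2));
    }
}
```

One caveat with both forms: for large negative x (e.g. x = -1000), `Math.exp(-x)` overflows to infinity and the quotient becomes NaN, whereas `Math.tanh(x / 2)` returns -1.0.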
There are lots of places to put this sort of thing to use in BP nets.
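For reference, both slopes can be written in terms of the function's own output: for the logistic sigmoid s'(x) = s(x)(1 - s(x)), and for the bipolar sigmoid f'(x) = (1 - f(x)^2) / 2. Reusing the output already computed in the forward pass means backpropagation needs no extra `exp()` call. A minimal sketch (names hypothetical) with a finite-difference sanity check:

```java
// Sketch: derivatives expressed via the forward-pass output, so no extra exp() call is needed.
public class SigmoidSlopes {
    /** Logistic sigmoid s(x) = 1 / (1 + e^-x). */
    public static double logistic(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    /** Slope of the logistic sigmoid, given its output y = s(x): s'(x) = y * (1 - y). */
    public static double logisticSlopeFromOutput(double y) {
        return y * (1 - y);
    }

    /** Slope of the bipolar sigmoid, given its output y = f(x): f'(x) = (1 - y * y) / 2. */
    public static double bipolarSlopeFromOutput(double y) {
        return (1 - y * y) / 2;
    }

    public static void main(String[] args) {
        double x = 0.3;
        double y = logistic(x);   // computed once in the forward pass
        double h = 1e-6;          // step for a central-difference sanity check
        double numeric = (logistic(x + h) - logistic(x - h)) / (2 * h);
        System.out.println(logisticSlopeFromOutput(y) + " vs numeric " + numeric);
    }
}
```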
It's a pretty smooth function, so a lookup and interpolation scheme is likely to be more than sufficient.
When I plot the function over the range -10 <= x <= 10, I get five-place accuracy at the extremes. Is that good enough for your application?
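To see what accuracy a given shortcut actually gives, you can measure it directly. This sketch reports the error from treating the function as exactly 0 or 1 beyond |x| = 10, and the worst sampled error of linear interpolation over [-10, 10]. The class name, the grid spacing of 0.02 and the sampling density are arbitrary assumptions.

```java
// Sketch: how accurate are the shortcuts over [-10, 10]?
// (1) error from clamping to 0 or 1 beyond |x| = 10,
// (2) worst sampled error of linear interpolation on a uniform grid of spacing h.
public class SigmoidAccuracy {
    public static double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    /** Largest |interpolated - exact| over [lo, hi] for grid spacing h, probed at 19 points per cell. */
    public static double maxInterpolationError(double lo, double hi, double h) {
        double worst = 0;
        int cells = (int) Math.round((hi - lo) / h);
        for (int i = 0; i < cells; i++) {
            double x0 = lo + i * h, x1 = x0 + h;
            double y0 = sigmoid(x0), y1 = sigmoid(x1);
            for (int k = 1; k < 20; k++) {        // probe points inside the cell
                double t = k / 20.0;
                double approx = y0 + t * (y1 - y0);
                double exact = sigmoid(x0 + t * h);
                worst = Math.max(worst, Math.abs(approx - exact));
            }
        }
        return worst;
    }

    public static void main(String[] args) {
        System.out.println("sigmoid(-10)     = " + sigmoid(-10));
        System.out.println("1 - sigmoid(10)  = " + (1 - sigmoid(10)));
        System.out.println("max interp error = " + maxInterpolationError(-10, 10, 0.02));
    }
}
```

Both extreme values are about 4.54e-5, so clamping beyond |x| = 10 is off by less than 5e-5. For linear interpolation, the standard error bound is h^2/8 times the maximum of |sigmoid''(x)| (about 0.096), which for h = 0.02 is just under 5e-6.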
From a mathematical point of view, I don't see any way to optimize it.