Creating custom connectivity in PyBrain neural networks

Developer https://www.devze.com 2023-02-27 01:50 Source: Internet
I want to create an artificial neural network (in PyBrain) that follows the following layout:

[Image: diagram of the desired network layout]

However, I cannot find the proper way to achieve this. The only option that I see in the documentation is the way to create fully connected layers, which is not what I want: I want some of my input nodes to be connected to the second hidden layer and not to the first one.


The solution is to use the connection type of your choice, but with slicing parameters: inSliceFrom, inSliceTo, outSliceFrom and outSliceTo. I agree the documentation should mention this; so far it is only in the Connection class's comments.

Here is example code for your case:

from pybrain.structure import (FeedForwardNetwork, FullConnection,
                               LinearLayer, SigmoidLayer, TanhLayer)

# create network and modules
net = FeedForwardNetwork()
inp = LinearLayer(9)
h1 = SigmoidLayer(2)
h2 = TanhLayer(2)
outp = LinearLayer(1)
# add modules
net.addOutputModule(outp)
net.addInputModule(inp)
net.addModule(h1)
net.addModule(h2)
# create connections
# the first six inputs (indices 0-5) feed h1
net.addConnection(FullConnection(inp, h1, inSliceTo=6))
# the remaining three inputs (indices 6-8) skip h1 and feed h2 directly
net.addConnection(FullConnection(inp, h2, inSliceFrom=6))
net.addConnection(FullConnection(h1, h2))
net.addConnection(FullConnection(h2, outp))
# finish up
net.sortModules()
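
To make the effect of the slice parameters more concrete, here is a rough sketch in plain Python (it does not use PyBrain at all). The helper names sliced_full_forward and n_weights are made up for illustration, and the sketch assumes that a full connection holds one weight per (source unit, target unit) pair with no bias term. Under that assumption, a sliced connection is simply a dense weight matrix applied to a window of the source layer's output.

```python
# Conceptual sketch only: plain Python, no PyBrain.
# All function and variable names here are hypothetical.

def sliced_full_forward(inbuf, weights, in_from, in_to):
    """Apply a dense weight matrix to inbuf[in_from:in_to] only.

    weights is a list of rows, one row per target unit; each row holds
    (in_to - in_from) weights.
    """
    window = inbuf[in_from:in_to]
    return [sum(w * x for w, x in zip(row, window)) for row in weights]


def n_weights(in_from, in_to, out_dim):
    """Weights in a full connection over an input slice (assumes no bias)."""
    return (in_to - in_from) * out_dim


inputs = [1.0] * 9  # nine input values, matching LinearLayer(9)

# Arbitrary example weights: h1 sees inputs 0..5, h2 sees inputs 6..8.
w_inp_h1 = [[0.1] * 6, [0.2] * 6]
w_inp_h2 = [[1.0] * 3, [2.0] * 3]

to_h1 = sliced_full_forward(inputs, w_inp_h1, 0, 6)  # like inSliceTo=6
to_h2 = sliced_full_forward(inputs, w_inp_h2, 6, 9)  # like inSliceFrom=6

weights_total = (
    n_weights(0, 6, 2)    # inp[0:6] -> h1
    + n_weights(6, 9, 2)  # inp[6:9] -> h2
    + n_weights(0, 2, 2)  # h1 -> h2
    + n_weights(0, 2, 1)  # h2 -> outp
)
print(to_h2, weights_total)
```

If the no-bias assumption holds, the layout above has 12 + 6 + 4 + 2 = 24 weights in total, versus 18 + 4 + 2 = 24 only by coincidence for a plain fully connected 9-2-2-1 chain; the difference is in which inputs reach which layer, not in the count.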


An alternative to the approach suggested by schaul is to use multiple input layers.

from pybrain.structure import (FeedForwardNetwork, FullConnection,
                               LinearLayer, SigmoidLayer)

# create network
net = FeedForwardNetwork()

# create and add modules
# input_1 holds the six inputs that feed the first hidden layer
input_1 = LinearLayer(6)
net.addInputModule(input_1)
# input_2 holds the three inputs that feed the second hidden layer directly
input_2 = LinearLayer(3)
net.addInputModule(input_2)
h1 = SigmoidLayer(2)
net.addModule(h1)
h2 = SigmoidLayer(2)
net.addModule(h2)
outp = SigmoidLayer(1)
net.addOutputModule(outp)

# create connections
net.addConnection(FullConnection(input_1, h1))
net.addConnection(FullConnection(input_2, h2))
net.addConnection(FullConnection(h1, h2))
net.addConnection(FullConnection(h2, outp))

net.sortModules()
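
As a quick sanity check that the two answers describe the same wiring, here is a small plain-Python sketch (again without PyBrain; the function names are hypothetical). It lists which input unit feeds which hidden layer under each layout. It assumes that the units of input_2 correspond to positions 6 to 8 of the overall 9-value input, i.e. that the two input layers are laid out in the order they were added; I have not verified that ordering against the PyBrain source, so treat it as an assumption.

```python
# Conceptual sketch only: plain Python, no PyBrain.
# All function names here are hypothetical.

def edges_with_slices():
    """One 9-unit input layer, split via inSliceTo=6 / inSliceFrom=6."""
    edges = set()
    for i in range(0, 6):
        edges.add((i, "h1"))  # inputs 0..5 -> first hidden layer
    for i in range(6, 9):
        edges.add((i, "h2"))  # inputs 6..8 -> second hidden layer
    return edges


def edges_with_two_inputs(offset_of_input_2=6):
    """Two input layers (sizes 6 and 3), each fully connected.

    offset_of_input_2 is the ASSUMED position of input_2's first unit
    within the overall 9-value input vector.
    """
    edges = set()
    for i in range(6):
        edges.add((i, "h1"))  # input_1 -> h1
    for j in range(3):
        edges.add((offset_of_input_2 + j, "h2"))  # input_2 -> h2
    return edges


same_layout = edges_with_slices() == edges_with_two_inputs()
print(same_layout)
```

Under that ordering assumption the input-to-hidden wiring is identical, so the choice between the two approaches is mostly a matter of whether you prefer one input layer with slices or separate, explicitly named input layers.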