-- Source-viewer header (extraction residue): 32 lines, 770 B, Lua
--- Builder for a stacked tanh RNN cell expressed as a graph module.
-- Relies on a global `nn` table that provides `Identity`, `Linear`, `Tanh`,
-- `CAddTable` and `gModule`, and whose module instances can be called on
-- graph nodes (presumably Torch `nn` with `nngraph` loaded -- confirm at the
-- call site, since this file does not `require` them itself).
local RNN = {}

--- Build one time step of an n-layer RNN.
-- The returned module takes n+1 inputs, in this order:
--   1. x           (the input to the first layer)
--   2. prev_h[1]   (previous hidden state of layer 1)
--   ...
--   n+1. prev_h[n] (previous hidden state of layer n)
-- and produces n outputs: next_h[1] .. next_h[n], where
--   next_h[L] = tanh(Linear(in_L) + Linear(prev_h[L]))
-- and in_L is x for the first layer and next_h[L-1] for every other layer.
-- @param input_size width passed to the first layer's input-to-hidden Linear
-- @param rnn_size   hidden width used for every layer
-- @param n          number of stacked layers; must be a number >= 1
-- @return the value of `nn.gModule(inputs, outputs)`
function RNN.rnn(input_size, rnn_size, n)
  -- With fewer than one layer the loops below never run and the graph would
  -- be built with no outputs; fail early with a clear message instead.
  if type(n) ~= "number" or n < 1 then
    error("RNN.rnn: n (number of layers) must be a number >= 1", 2)
  end

  -- there are n+1 inputs (hiddens on each layer and x)
  local inputs = {}
  inputs[1] = nn.Identity()() -- x
  for L = 1, n do
    inputs[L + 1] = nn.Identity()() -- prev_h[L]
  end

  local outputs = {}
  for L = 1, n do
    local prev_h = inputs[L + 1]

    -- The first layer reads the external input; deeper layers read the
    -- hidden state just produced by the layer below.
    local x, input_size_L
    if L == 1 then
      x, input_size_L = inputs[1], input_size
    else
      x, input_size_L = outputs[L - 1], rnn_size
    end

    -- RNN tick
    local i2h = nn.Linear(input_size_L, rnn_size)(x)
    local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)
    local next_h = nn.Tanh()(nn.CAddTable()({ i2h, h2h }))

    outputs[L] = next_h
  end

  return nn.gModule(inputs, outputs)
end

return RNN