local GRU = {}

--[[
Creates one timestep of one GRU.
Paper reference: http://arxiv.org/pdf/1412.3555v1.pdf

Arguments:
  input_size : input width given to the nn.Linear modules that read x on layer 1
  rnn_size   : hidden width used by every nn.Linear on every layer
  n          : number of stacked layers

Returns an nn.gModule with n+1 inputs (x, then the previous hidden state of
each layer) and n outputs (the next hidden state of each layer).

Relies on the globals `nn` / `nngraph` being loaded by the caller.
]]--
function GRU.gru(input_size, rnn_size, n)
  -- there are n+1 inputs (hiddens on each layer and x)
  local inputs = {}
  table.insert(inputs, nn.Identity()()) -- x
  for L = 1, n do
    table.insert(inputs, nn.Identity()()) -- prev_h[L]
  end

  -- Sum of a linear map of the layer input and a linear map of the previous
  -- hidden state. Declared `local` so that building a network does not leak
  -- (or overwrite) a global named `new_input_sum`.
  local function new_input_sum(insize, xv, hv)
    local i2h = nn.Linear(insize, rnn_size)(xv)
    local h2h = nn.Linear(rnn_size, rnn_size)(hv)
    return nn.CAddTable()({i2h, h2h})
  end

  local outputs = {}
  for L = 1, n do
    local prev_h = inputs[L + 1]

    -- Layer 1 reads the external input; deeper layers read the hidden state
    -- just produced by the layer below, whose width is rnn_size.
    local x, input_size_L
    if L == 1 then
      x = inputs[1]
      input_size_L = input_size
    else
      x = outputs[L - 1]
      input_size_L = rnn_size
    end

    -- GRU tick
    -- forward the update and reset gates
    local update_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))
    local reset_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))

    -- compute candidate hidden state
    local gated_hidden = nn.CMulTable()({reset_gate, prev_h})
    local p2 = nn.Linear(rnn_size, rnn_size)(gated_hidden)
    local p1 = nn.Linear(input_size_L, rnn_size)(x)
    local hidden_candidate = nn.Tanh()(nn.CAddTable()({p1, p2}))

    -- compute new interpolated hidden state, based on the update gate:
    --   next_h = z * candidate + (1 - z) * prev_h
    local zh = nn.CMulTable()({update_gate, hidden_candidate})
    -- (1 - z) is built as (-1 * z) + 1
    local one_minus_z = nn.AddConstant(1, false)(nn.MulConstant(-1, false)(update_gate))
    local zhm1 = nn.CMulTable()({one_minus_z, prev_h})
    local next_h = nn.CAddTable()({zh, zhm1})

    table.insert(outputs, next_h)
  end

  return nn.gModule(inputs, outputs)
end

return GRU