use biases of 1.0 for the forget gate
This commit is contained in:
parent
1b05d91e23
commit
0a45fbe230
@ -27,18 +27,24 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
|
||||
input_size_L = rnn_size
|
||||
end
|
||||
-- evaluate the input sums at once for efficiency
|
||||
local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x)
|
||||
local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
|
||||
local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L}
|
||||
local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L}
|
||||
local all_input_sums = nn.CAddTable()({i2h, h2h})
|
||||
-- decode the gates
|
||||
local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
|
||||
sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
|
||||
local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
|
||||
local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
|
||||
local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
|
||||
-- local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
|
||||
-- sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
|
||||
-- local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
|
||||
-- local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
|
||||
-- local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
|
||||
local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
|
||||
local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
|
||||
local in_gate = nn.Sigmoid()(n1)
|
||||
local forget_gate = nn.Sigmoid()(n2)
|
||||
local out_gate = nn.Sigmoid()(n3)
|
||||
-- decode the write inputs
|
||||
local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
|
||||
in_transform = nn.Tanh()(in_transform)
|
||||
-- local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
|
||||
-- in_transform = nn.Tanh()(in_transform)
|
||||
local in_transform = nn.Tanh()(n4)
|
||||
-- perform the LSTM update
|
||||
local next_c = nn.CAddTable()({
|
||||
nn.CMulTable()({forget_gate, prev_c}),
|
||||
@ -54,7 +60,7 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
|
||||
-- set up the decoder
|
||||
local top_h = outputs[#outputs]
|
||||
if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
|
||||
local proj = nn.Linear(rnn_size, input_size)(top_h)
|
||||
local proj = nn.Linear(rnn_size, input_size)(top_h):annotate{name='decoder'}
|
||||
local logsoft = nn.LogSoftMax()(proj)
|
||||
table.insert(outputs, logsoft)
|
||||
|
||||
|
30
train.lua
30
train.lua
@ -54,7 +54,7 @@ cmd:option('-init_from', '', 'initialize network parameters from checkpoint at t
|
||||
-- bookkeeping
|
||||
cmd:option('-seed',123,'torch manual random number generator seed')
|
||||
cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss')
|
||||
cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?')
|
||||
cmd:option('-eval_val_every',2000,'every how many iterations should we evaluate on validation data?')
|
||||
cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
|
||||
cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be inside checkpoint_dir/')
|
||||
-- GPU/CPU
|
||||
@ -62,9 +62,9 @@ cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
|
||||
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
|
||||
-- Scheduled Sampling
|
||||
cmd:option('-use_ss', 1, 'whether use scheduled sampling during training')
|
||||
cmd:option('-start_ss', 1, 'start amount of truth to be github to the model when using ss')
|
||||
cmd:option('-decay_ss', 0.01666, 'ss amount decay rate of each epoch')
|
||||
cmd:option('-min_ss', 0.5, 'minimum amount of truth to be given to the model when using ss')
|
||||
cmd:option('-start_ss', 1, 'start amount of truth data to be given to the model when using ss')
|
||||
cmd:option('-decay_ss', 0.005, 'ss amount decay rate of each epoch')
|
||||
cmd:option('-min_ss', 0.9, 'minimum amount of truth data to be given to the model when using ss')
|
||||
cmd:text()
|
||||
|
||||
-- parse input params
|
||||
@ -165,11 +165,23 @@ end
|
||||
|
||||
-- put the above things into one flattened parameters tensor
|
||||
params, grad_params = model_utils.combine_all_parameters(protos.rnn)
|
||||
|
||||
-- initialization
|
||||
if do_random_init then
|
||||
params:uniform(-0.08, 0.08) -- small numbers uniform
|
||||
end
|
||||
-- initialize the LSTM forget gates with slightly higher biases to encourage remembering in the beginning
|
||||
if opt.model == 'lstm' then
|
||||
for layer_idx = 1, opt.num_layers do
|
||||
for _,node in ipairs(protos.rnn.forwardnodes) do
|
||||
if node.data.annotations.name == "i2h_" .. layer_idx then
|
||||
print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx)
|
||||
-- the gates are, in order, i,f,o,g, so f is the 2nd block of weights
|
||||
node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1.0)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
print('number of parameters in the model: ' .. params:nElement())
|
||||
-- make a bunch of clones after flattening, as that reallocates memory
|
||||
@ -212,7 +224,7 @@ function eval_split(split_index, max_batches)
|
||||
end
|
||||
-- carry over lstm state
|
||||
rnn_state[0] = rnn_state[#rnn_state]
|
||||
print(i .. '/' .. n .. '...')
|
||||
-- print(i .. '/' .. n .. '...')
|
||||
end
|
||||
|
||||
loss = loss / opt.seq_length / n
|
||||
@ -244,7 +256,6 @@ function feval(x)
|
||||
local loss = 0
|
||||
for t=1,opt.seq_length do
|
||||
clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
|
||||
-- flip a coin to decide weather use scheduled sampling
|
||||
if opt.use_ss == 1 and t > 1 and math.random() > ss_current then
|
||||
local probs = torch.exp(predictions[t-1]):squeeze()
|
||||
_,samples = torch.max(probs,2)
|
||||
@ -252,6 +263,7 @@ function feval(x)
|
||||
else
|
||||
xx = x[{{}, t}]
|
||||
end
|
||||
-- print(x[{{},t}])
|
||||
local lst = clones.rnn[t]:forward{xx, unpack(rnn_state[t-1])}
|
||||
rnn_state[t] = {}
|
||||
for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
|
||||
@ -344,6 +356,10 @@ for i = 1, iterations do
|
||||
if i % 10 == 0 then collectgarbage() end
|
||||
|
||||
-- handle early stopping if things are going really bad
|
||||
if loss[1] ~= loss[1] then
|
||||
print('loss is NaN. This usually indicates a bug. Please check the issues page for existing issues, or create a new issue, if none exist. Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?')
|
||||
break -- halt
|
||||
end
|
||||
if loss0 == nil then loss0 = loss[1] end
|
||||
if loss[1] > loss0 * 3 then
|
||||
print('loss is exploding, aborting.')
|
||||
|
Loading…
Reference in New Issue
Block a user