use biases of 1.0 for the forget gate

Jeff Zhang 2015-09-25 10:24:38 +08:00
parent 1b05d91e23
commit 0a45fbe230
2 changed files with 40 additions and 18 deletions

model/LSTM.lua

@@ -27,18 +27,24 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
       input_size_L = rnn_size
     end
     -- evaluate the input sums at once for efficiency
-    local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x)
-    local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
+    local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L}
+    local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L}
     local all_input_sums = nn.CAddTable()({i2h, h2h})
     -- decode the gates
-    local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
-    sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
-    local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
-    local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
-    local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
+    -- local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
+    -- sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
+    -- local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
+    -- local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
+    -- local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
+    local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
+    local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
+    local in_gate = nn.Sigmoid()(n1)
+    local forget_gate = nn.Sigmoid()(n2)
+    local out_gate = nn.Sigmoid()(n3)
     -- decode the write inputs
-    local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
-    in_transform = nn.Tanh()(in_transform)
+    -- local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
+    -- in_transform = nn.Tanh()(in_transform)
+    local in_transform = nn.Tanh()(n4)
     -- perform the LSTM update
     local next_c = nn.CAddTable()({
       nn.CMulTable()({forget_gate, prev_c}),
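Note on the gate decoding above: the 4 * rnn_size pre-activations are laid out as the blocks [i | f | o | g] along dimension 2, so nn.Reshape(4, rnn_size) followed by nn.SplitTable(2) yields exactly the four blocks the old nn.Narrow calls sliced out, while the :annotate names make the Linear modules findable later. A minimal sketch of that equivalence, assuming only stock Torch7 (sizes are arbitrary):

require 'torch'

local batch, rnn_size = 2, 3
-- fake pre-activations: layout along dim 2 is [i | f | o | g]
local sums = torch.randn(batch, 4 * rnn_size)
-- old decoding: slice the forget-gate block with narrow
local f_narrow = sums:narrow(2, rnn_size + 1, rnn_size)
-- new decoding: reshape to (batch, 4, rnn_size) and pick the 2nd block
local f_split = sums:view(batch, 4, rnn_size):select(2, 2)
print((f_narrow - f_split):abs():max())  -- prints 0: the two slices match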
@@ -54,7 +60,7 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
   -- set up the decoder
   local top_h = outputs[#outputs]
   if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
-  local proj = nn.Linear(rnn_size, input_size)(top_h)
+  local proj = nn.Linear(rnn_size, input_size)(top_h):annotate{name='decoder'}
   local logsoft = nn.LogSoftMax()(proj)
   table.insert(outputs, logsoft)
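The :annotate{name='decoder'} tag serves the same purpose as the i2h_/h2h_ names above: nngraph stores it on the node so the module can be recovered from the compiled graph. A self-contained sketch of that lookup, mirroring how the new code in train.lua walks forwardnodes (illustrative sizes):

require 'nngraph'

-- a tiny gModule with one annotated Linear
local x = nn.Identity()()
local proj = nn.Linear(5, 3)(x):annotate{name='decoder'}
local g = nn.gModule({x}, {proj})
-- recover the module by its annotation
for _, node in ipairs(g.forwardnodes) do
    if node.data.annotations.name == 'decoder' then
        print('found decoder, bias has ' .. node.data.module.bias:nElement() .. ' elements')
    end
end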

train.lua

@@ -54,7 +54,7 @@ cmd:option('-init_from', '', 'initialize network parameters from checkpoint at t
 -- bookkeeping
 cmd:option('-seed',123,'torch manual random number generator seed')
 cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss')
-cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?')
+cmd:option('-eval_val_every',2000,'every how many iterations should we evaluate on validation data?')
 cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
 cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be inside checkpoint_dir/')
 -- GPU/CPU
@@ -62,9 +62,9 @@ cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
 cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
 -- Scheduled Sampling
 cmd:option('-use_ss', 1, 'whether use scheduled sampling during training')
-cmd:option('-start_ss', 1, 'start amount of truth to be given to the model when using ss')
-cmd:option('-decay_ss', 0.01666, 'ss amount decay rate of each epoch')
-cmd:option('-min_ss', 0.5, 'minimum amount of truth to be given to the model when using ss')
+cmd:option('-start_ss', 1, 'start amount of truth data to be given to the model when using ss')
+cmd:option('-decay_ss', 0.005, 'ss amount decay rate of each epoch')
+cmd:option('-min_ss', 0.9, 'minimum amount of truth data to be given to the model when using ss')
 cmd:text()
 -- parse input params
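A hedged reading of the retuned defaults: if, as the option strings suggest, the truth ratio starts at start_ss and loses decay_ss per epoch until clamped at min_ss, the old settings could sink to feeding only 50% ground truth, while the new ones never drop below 90%, hitting that floor after (1 - 0.9) / 0.005 = 20 epochs. A sketch of that assumed schedule (the actual update rule is not shown in this hunk; ss_amount is a hypothetical helper):

local start_ss, decay_ss, min_ss = 1.0, 0.005, 0.9
-- assumed linear per-epoch decay, clamped at min_ss
local function ss_amount(epoch)
    return math.max(min_ss, start_ss - decay_ss * (epoch - 1))
end
for _, e in ipairs({1, 11, 21, 50}) do
    print(string.format('epoch %3d: feed ground truth with prob %.3f', e, ss_amount(e)))
end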
@@ -165,11 +165,23 @@ end
 -- put the above things into one flattened parameters tensor
 params, grad_params = model_utils.combine_all_parameters(protos.rnn)
 -- initialization
 if do_random_init then
     params:uniform(-0.08, 0.08) -- small numbers uniform
 end
+-- initialize the LSTM forget gates with slightly higher biases to encourage remembering in the beginning
+if opt.model == 'lstm' then
+    for layer_idx = 1, opt.num_layers do
+        for _,node in ipairs(protos.rnn.forwardnodes) do
+            if node.data.annotations.name == "i2h_" .. layer_idx then
+                print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx)
+                -- the gates are, in order, i,f,o,g, so f is the 2nd block of weights
+                node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1.0)
+            end
+        end
+    end
+end
 print('number of parameters in the model: ' .. params:nElement())
 -- make a bunch of clones after flattening, as that reallocates memory
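Why the range {opt.rnn_size+1, 2*opt.rnn_size}: the i2h Linear stacks all four gate pre-activations as [i | f | o | g], so its bias of length 4 * rnn_size holds the forget-gate biases in the second rnn_size-long block; starting them at 1.0 puts the forget gate near sigmoid(1) ≈ 0.73, so the cell mostly retains its state early in training. A standalone illustration of the indexing (stock nn; sizes are arbitrary):

require 'nn'

local rnn_size = 4
local i2h = nn.Linear(8, 4 * rnn_size)  -- one layer's input-to-hidden map
i2h.bias:zero()
-- fill only the forget-gate block, exactly as the init code above does
i2h.bias[{{rnn_size + 1, 2 * rnn_size}}]:fill(1.0)
print(i2h.bias:view(4, rnn_size))  -- row 2 (the f block) is all ones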
@@ -212,7 +224,7 @@ function eval_split(split_index, max_batches)
         end
         -- carry over lstm state
         rnn_state[0] = rnn_state[#rnn_state]
-        print(i .. '/' .. n .. '...')
+        -- print(i .. '/' .. n .. '...')
     end
     loss = loss / opt.seq_length / n
@@ -244,7 +256,6 @@ function feval(x)
     local loss = 0
     for t=1,opt.seq_length do
         clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
-        -- flip a coin to decide weather use scheduled sampling
        if opt.use_ss == 1 and t > 1 and math.random() > ss_current then
            local probs = torch.exp(predictions[t-1]):squeeze()
            _,samples = torch.max(probs,2)
@@ -252,6 +263,7 @@ function feval(x)
         else
             xx = x[{{}, t}]
         end
+        -- print(x[{{},t}])
         local lst = clones.rnn[t]:forward{xx, unpack(rnn_state[t-1])}
         rnn_state[t] = {}
         for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
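For readers skimming these two hunks: each timestep flips a biased coin, and with probability 1 - ss_current the network is fed its own greedy prediction from step t-1 instead of the ground-truth symbol. A compact sketch isolating that decision (pick_input is a hypothetical helper; predictions[t-1] is assumed to hold batch x vocab log-probabilities, as in train.lua):

local function pick_input(x, t, predictions, ss_current)
    if t > 1 and math.random() > ss_current then
        local probs = torch.exp(predictions[t - 1]):squeeze()  -- back to probabilities
        local _, samples = torch.max(probs, 2)  -- greedy argmax per batch row
        return samples:view(-1)                 -- feed the model its own guess
    end
    return x[{{}, t}]                           -- feed the ground truth
end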
@@ -344,6 +356,10 @@ for i = 1, iterations do
     if i % 10 == 0 then collectgarbage() end
     -- handle early stopping if things are going really bad
+    if loss[1] ~= loss[1] then
+        print('loss is NaN. This usually indicates a bug. Please check the issues page for existing issues, or create a new issue, if none exist. Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?')
+        break -- halt
+    end
     if loss0 == nil then loss0 = loss[1] end
     if loss[1] > loss0 * 3 then
         print('loss is exploding, aborting.')
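The new guard works because NaN is the only value that fails self-equality, so loss[1] ~= loss[1] is true exactly when the loss has gone NaN, stopping training before a useless checkpoint is written. A two-line demonstration:

local nan = 0 / 0  -- IEEE 754 NaN in Lua
print(nan ~= nan)  -- true: NaN never equals itself
print(1.5 ~= 1.5)  -- false for every ordinary number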