use biases of 1.0 for the forget gate

parent 1b05d91e23
commit 0a45fbe230
@@ -27,18 +27,24 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
       input_size_L = rnn_size
     end
     -- evaluate the input sums at once for efficiency
-    local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x)
-    local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
+    local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L}
+    local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L}
     local all_input_sums = nn.CAddTable()({i2h, h2h})
     -- decode the gates
-    local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
-    sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
-    local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
-    local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
-    local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
+    -- local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
+    -- sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
+    -- local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
+    -- local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
+    -- local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
+    local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
+    local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
+    local in_gate = nn.Sigmoid()(n1)
+    local forget_gate = nn.Sigmoid()(n2)
+    local out_gate = nn.Sigmoid()(n3)
     -- decode the write inputs
-    local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
-    in_transform = nn.Tanh()(in_transform)
+    -- local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
+    -- in_transform = nn.Tanh()(in_transform)
+    local in_transform = nn.Tanh()(n4)
     -- perform the LSTM update
     local next_c = nn.CAddTable()({
         nn.CMulTable()({forget_gate, prev_c}),
@@ -54,7 +60,7 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
   -- set up the decoder
   local top_h = outputs[#outputs]
   if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
-  local proj = nn.Linear(rnn_size, input_size)(top_h)
+  local proj = nn.Linear(rnn_size, input_size)(top_h):annotate{name='decoder'}
   local logsoft = nn.LogSoftMax()(proj)
   table.insert(outputs, logsoft)
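The rewrite of the gate decoding above swaps the Narrow-based slicing for an nn.Reshape into (4, rnn_size) blocks followed by nn.SplitTable; the annotate names added to the i2h_<L> modules are what the bias initialization in train.lua later keys on. A minimal standalone sketch of why the two routes give the same forget-gate activations, assuming Torch7 with the nn package (batch size and rnn_size below are illustrative, not taken from the commit):

    require 'nn'

    local rnn_size = 3
    local all_input_sums = torch.randn(2, 4 * rnn_size)   -- batch of 2; pre-activations for i, f, o, g

    -- old route: slice the forget-gate block directly out of the flat vector
    local forget_old = nn.Sigmoid():forward(
        nn.Narrow(2, rnn_size + 1, rnn_size):forward(all_input_sums))

    -- new route: reshape to (batch, 4, rnn_size) and split into the four gate blocks
    local reshaped = nn.Reshape(4, rnn_size):forward(all_input_sums)
    local chunks = nn.SplitTable(2):forward(reshaped)      -- {n1, n2, n3, n4}
    local forget_new = nn.Sigmoid():forward(chunks[2])     -- f is the 2nd block

    print((forget_old - forget_new):abs():max())           -- prints 0: identical activations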
train.lua
@@ -54,7 +54,7 @@ cmd:option('-init_from', '', 'initialize network parameters from checkpoint at t
 -- bookkeeping
 cmd:option('-seed',123,'torch manual random number generator seed')
 cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss')
-cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?')
+cmd:option('-eval_val_every',2000,'every how many iterations should we evaluate on validation data?')
 cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
 cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be inside checkpoint_dir/')
 -- GPU/CPU
@@ -62,9 +62,9 @@ cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
 cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
 -- Scheduled Sampling
 cmd:option('-use_ss', 1, 'whether use scheduled sampling during training')
-cmd:option('-start_ss', 1, 'start amount of truth to be github to the model when using ss')
-cmd:option('-decay_ss', 0.01666, 'ss amount decay rate of each epoch')
-cmd:option('-min_ss', 0.5, 'minimum amount of truth to be given to the model when using ss')
+cmd:option('-start_ss', 1, 'start amount of truth data to be given to the model when using ss')
+cmd:option('-decay_ss', 0.005, 'ss amount decay rate of each epoch')
+cmd:option('-min_ss', 0.9, 'minimum amount of truth data to be given to the model when using ss')
 cmd:text()

 -- parse input params
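The three options above only declare the scheduled-sampling schedule; the code that updates ss_current each epoch is not part of this diff. A purely hypothetical sketch of how start_ss, decay_ss and min_ss could combine, consistent with the new defaults (start at 1, decay slowly, never below 0.9):

    -- hypothetical: decay the probability of feeding ground truth each epoch,
    -- floored at min_ss (the real update code lives elsewhere in train.lua)
    local ss_current = opt.start_ss
    local function update_ss(epoch)
        ss_current = math.max(opt.min_ss, opt.start_ss - opt.decay_ss * epoch)
    end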
@@ -165,11 +165,23 @@ end

 -- put the above things into one flattened parameters tensor
 params, grad_params = model_utils.combine_all_parameters(protos.rnn)

 -- initialization
 if do_random_init then
     params:uniform(-0.08, 0.08) -- small numbers uniform
 end
+-- initialize the LSTM forget gates with slightly higher biases to encourage remembering in the beginning
+if opt.model == 'lstm' then
+    for layer_idx = 1, opt.num_layers do
+        for _,node in ipairs(protos.rnn.forwardnodes) do
+            if node.data.annotations.name == "i2h_" .. layer_idx then
+                print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx)
+                -- the gates are, in order, i,f,o,g, so f is the 2nd block of weights
+                node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1.0)
+            end
+        end
+    end
+end

 print('number of parameters in the model: ' .. params:nElement())
 -- make a bunch of clones after flattening, as that reallocates memory
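The new initialization block relies on the i2h Linear bias being a flat vector of length 4*rnn_size laid out as i, f, o, g, so the forget-gate entries are the second rnn_size-sized slice. A minimal sketch of that indexing, assuming Torch7's nn (sizes are illustrative, and the module below is a stand-in for one annotated i2h_<L> layer):

    require 'nn'

    local rnn_size = 3
    local i2h = nn.Linear(5, 4 * rnn_size)             -- stand-in for one i2h_<L> module
    i2h.bias:zero()
    i2h.bias[{{rnn_size + 1, 2 * rnn_size}}]:fill(1.0)
    print(i2h.bias:view(4, rnn_size))                  -- row 2 (the f block) is all ones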
@@ -212,7 +224,7 @@ function eval_split(split_index, max_batches)
         end
         -- carry over lstm state
         rnn_state[0] = rnn_state[#rnn_state]
-        print(i .. '/' .. n .. '...')
+        -- print(i .. '/' .. n .. '...')
     end

     loss = loss / opt.seq_length / n
@@ -244,7 +256,6 @@ function feval(x)
     local loss = 0
     for t=1,opt.seq_length do
         clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
-        -- flip a coin to decide weather use scheduled sampling
         if opt.use_ss == 1 and t > 1 and math.random() > ss_current then
             local probs = torch.exp(predictions[t-1]):squeeze()
             _,samples = torch.max(probs,2)
@@ -252,6 +263,7 @@ function feval(x)
         else
             xx = x[{{}, t}]
         end
+        -- print(x[{{},t}])
         local lst = clones.rnn[t]:forward{xx, unpack(rnn_state[t-1])}
         rnn_state[t] = {}
         for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
@@ -344,6 +356,10 @@ for i = 1, iterations do
     if i % 10 == 0 then collectgarbage() end

     -- handle early stopping if things are going really bad
+    if loss[1] ~= loss[1] then
+        print('loss is NaN. This usually indicates a bug. Please check the issues page for existing issues, or create a new issue, if none exist. Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?')
+        break -- halt
+    end
     if loss0 == nil then loss0 = loss[1] end
     if loss[1] > loss0 * 3 then
         print('loss is exploding, aborting.')
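The new guard works because NaN is the only value that compares unequal to itself, so loss[1] ~= loss[1] is true exactly when the loss has gone NaN; a tiny Lua illustration:

    local x = 0 / 0        -- NaN
    print(x ~= x)          -- true only for NaN
    print(1.5 ~= 1.5)      -- false for any ordinary number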