diff --git a/model/LSTM.lua b/model/LSTM.lua index bcaf39e..0b0d4cb 100644 --- a/model/LSTM.lua +++ b/model/LSTM.lua @@ -27,18 +27,24 @@ function LSTM.lstm(input_size, rnn_size, n, dropout) input_size_L = rnn_size end -- evaluate the input sums at once for efficiency - local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x) - local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h) + local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L} + local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L} local all_input_sums = nn.CAddTable()({i2h, h2h}) -- decode the gates - local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums) - sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk) - local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk) - local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk) - local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk) + -- local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums) + -- sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk) + -- local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk) + -- local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk) + -- local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk) + local reshaped = nn.Reshape(4, rnn_size)(all_input_sums) + local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) + local in_gate = nn.Sigmoid()(n1) + local forget_gate = nn.Sigmoid()(n2) + local out_gate = nn.Sigmoid()(n3) -- decode the write inputs - local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums) - in_transform = nn.Tanh()(in_transform) + -- local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums) + -- in_transform = nn.Tanh()(in_transform) + local in_transform = nn.Tanh()(n4) -- perform the LSTM update local next_c = nn.CAddTable()({ nn.CMulTable()({forget_gate, prev_c}), @@ -54,7 +60,7 @@ function LSTM.lstm(input_size, rnn_size, n, dropout) -- set up the 
decoder local top_h = outputs[#outputs] if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end - local proj = nn.Linear(rnn_size, input_size)(top_h) + local proj = nn.Linear(rnn_size, input_size)(top_h):annotate{name='decoder'} local logsoft = nn.LogSoftMax()(proj) table.insert(outputs, logsoft) diff --git a/train.lua b/train.lua index 900d1cc..06c064a 100644 --- a/train.lua +++ b/train.lua @@ -54,7 +54,7 @@ cmd:option('-init_from', '', 'initialize network parameters from checkpoint at t -- bookkeeping cmd:option('-seed',123,'torch manual random number generator seed') cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss') -cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?') +cmd:option('-eval_val_every',2000,'every how many iterations should we evaluate on validation data?') cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written') cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be inside checkpoint_dir/') -- GPU/CPU @@ -62,9 +62,9 @@ cmd:option('-gpuid',0,'which gpu to use. 
-1 = use CPU') cmd:option('-opencl',0,'use OpenCL (instead of CUDA)') -- Scheduled Sampling cmd:option('-use_ss', 1, 'whether use scheduled sampling during training') -cmd:option('-start_ss', 1, 'start amount of truth to be github to the model when using ss') -cmd:option('-decay_ss', 0.01666, 'ss amount decay rate of each epoch') -cmd:option('-min_ss', 0.5, 'minimum amount of truth to be given to the model when using ss') +cmd:option('-start_ss', 1, 'start amount of truth data to be given to the model when using ss') +cmd:option('-decay_ss', 0.005, 'ss amount decay rate of each epoch') +cmd:option('-min_ss', 0.9, 'minimum amount of truth data to be given to the model when using ss') cmd:text() -- parse input params @@ -165,11 +165,23 @@ end -- put the above things into one flattened parameters tensor params, grad_params = model_utils.combine_all_parameters(protos.rnn) - -- initialization if do_random_init then -params:uniform(-0.08, 0.08) -- small numbers uniform + params:uniform(-0.08, 0.08) -- small numbers uniform end +-- only when the parameters were just randomly initialized (so parameters that were not re-initialized keep their biases): start the LSTM forget gates with bias 1 to encourage remembering in the beginning +if opt.model == 'lstm' and do_random_init then + for layer_idx = 1, opt.num_layers do + for _,node in ipairs(protos.rnn.forwardnodes) do + if node.data.annotations.name == "i2h_" .. layer_idx then + print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx) + -- the gates are, in order, i,f,o,g, so f is the 2nd rnn_size-wide block of the i2h bias vector + node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1.0) + end + end + end +end + print('number of parameters in the model: ' .. params:nElement()) -- make a bunch of clones after flattening, as that reallocates memory @@ -212,7 +224,7 @@ function eval_split(split_index, max_batches) end -- carry over lstm state rnn_state[0] = rnn_state[#rnn_state] - print(i .. '/' .. n .. '...') + -- print(i .. '/' .. n ..
'...') end loss = loss / opt.seq_length / n @@ -244,7 +256,6 @@ function feval(x) local loss = 0 for t=1,opt.seq_length do clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag) - -- flip a coin to decide weather use scheduled sampling if opt.use_ss == 1 and t > 1 and math.random() > ss_current then local probs = torch.exp(predictions[t-1]):squeeze() _,samples = torch.max(probs,2) @@ -252,6 +263,7 @@ function feval(x) else xx = x[{{}, t}] end + -- print(x[{{},t}]) local lst = clones.rnn[t]:forward{xx, unpack(rnn_state[t-1])} rnn_state[t] = {} for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output @@ -344,6 +356,10 @@ for i = 1, iterations do if i % 10 == 0 then collectgarbage() end -- handle early stopping if things are going really bad + if loss[1] ~= loss[1] then + print('loss is NaN. This usually indicates a bug. Please check the issues page for existing issues, or create a new issue, if none exist. Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?') + break -- halt + end if loss0 == nil then loss0 = loss[1] end if loss[1] > loss0 * 3 then print('loss is exploding, aborting.')