add Scheduled Sampling at training

This commit is contained in:
Jeff Zhang 2015-09-09 10:22:58 +08:00
parent e9ad92ebfd
commit 9d37bba62c
2 changed files with 37 additions and 8 deletions

View File

@ -1,6 +1,6 @@
# char-rnn-chinese
Based on https://github.com/karpathy/char-rnn. make the code work well with Chinese.
Based on Andrej Karpathy's code https://github.com/karpathy/char-rnn and Samy Bengio's paper http://arxiv.org/abs/1506.03099
## Chinese process
Make the code can process both English and Chinese characters.
@ -10,6 +10,20 @@ This is my first touch of Lua, so the string process seems silly, but it works w
I also add an option called 'min_freq' because the vocab size in Chinese is very big, which makes the parameter num increase a lot.
So delete some rare character may help.
## Scheduled Sampling
Samy Bengio's paper [Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks](http://arxiv.org/abs/1506.03099) in NIPS15
propose a simple but power method to implove RNN.
In my experiment, I find it helps a lot to avoid overfitting and make the test loss go deeper. I only use linear decay.
Use `-use_ss` to turn on or turn off scheduled sampling, default is on. `-start_ss` is the start aomunt of real data, I suggest to use 1 because our model should learn data without noise at the very begining. `-min_ss` is also very important as too much noise will hurt performance. Finally, `-decay_ss` is the linear decay rate.
## Model conversion between cpu and gpu
I add a script to convert a model file trained by gpu to cpu model.
You can try it as follow:
```bash
$ th convert.lua gpu_model cpu_model
```
## web interface
A web demo is added for others to test model easily, based on sub/pub of redis.
I use redis because i can't found some good RPC or WebServer work well integrated with Torch.
@ -35,12 +49,6 @@ $ nohup th web_backend.lua &
$ nohup python web_server.py &
```
## Model conversion between cpu and gpu
I add a script to convert a model file trained by gpu to cpu model.
You can try it as follow:
```bash
$ th convert.lua gpu_model cpu_model
```
-----------------------------------------------
## Karpathy's raw Readme

View File

@ -60,11 +60,17 @@ cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be in
-- GPU/CPU
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
-- Scheduled Sampling
cmd:option('-use_ss', 1, 'whether use scheduled sampling during training')
cmd:option('-start_ss', 1, 'start amount of truth to be github to the model when using ss')
cmd:option('-decay_ss', 0.01666, 'ss amount decay rate of each epoch')
cmd:option('-min_ss', 0.5, 'minimum amount of truth to be given to the model when using ss')
cmd:text()
-- parse input params
opt = cmd:parse(arg)
torch.manualSeed(opt.seed)
math.randomseed(opt.seed)
-- train / val / test split for data, in fractions
local test_frac = math.max(0, 1 - (opt.train_frac + opt.val_frac))
local split_sizes = {opt.train_frac, opt.val_frac, test_frac}
@ -238,7 +244,15 @@ function feval(x)
local loss = 0
for t=1,opt.seq_length do
clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
local lst = clones.rnn[t]:forward{x[{{}, t}], unpack(rnn_state[t-1])}
-- flip a coin to decide weather use scheduled sampling
if opt.use_ss == 1 and t > 1 and math.random() > ss_current then
local probs = torch.exp(predictions[t-1]):squeeze()
_,samples = torch.max(probs,2)
xx = samples:view(samples:nElement())
else
xx = x[{{}, t}]
end
local lst = clones.rnn[t]:forward{xx, unpack(rnn_state[t-1])}
rnn_state[t] = {}
for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
predictions[t] = lst[#lst] -- last element is the prediction
@ -277,6 +291,7 @@ local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate}
local iterations = opt.max_epochs * loader.ntrain
local iterations_per_epoch = loader.ntrain
local loss0 = nil
ss_current = opt.start_ss
for i = 1, iterations do
local epoch = i / loader.ntrain
@ -296,6 +311,12 @@ for i = 1, iterations do
end
end
-- decay schedule sampling amount
if opt.use_ss == 1 and i % loader.ntrain == 0 and ss_current > opt.min_ss then
ss_current = opt.start_ss - opt.decay_ss * epoch
print('decay schedule sampling amount to ' .. ss_current)
end
-- every now and then or on last iteration
if i % opt.eval_val_every == 0 or i == iterations then
-- evaluate loss on validation data