Fixed handling of the UNKNOW token in sample.lua

This commit is contained in:
Jeff Zhang 2015-07-15 11:24:21 +08:00
parent e99bb2f368
commit 773a174534
2 changed files with 19 additions and 14 deletions

View File

@ -86,6 +86,7 @@ for L = 1,checkpoint.opt.num_layers do
end end
state_size = #current_state state_size = #current_state
-- parse characters from a string
function get_char(str) function get_char(str)
local len = #str local len = #str
local left = 0 local left = 0
@ -108,6 +109,7 @@ function get_char(str)
left = left + i left = left + i
unordered[#unordered+1] = tmpString unordered[#unordered+1] = tmpString
end end
return unordered
end end
-- do a few seeded timesteps -- do a few seeded timesteps
@ -116,8 +118,7 @@ if string.len(seed_text) > 0 then
gprint('seeding with ' .. seed_text) gprint('seeding with ' .. seed_text)
gprint('--------------------------') gprint('--------------------------')
local chars = get_char(seed_text) local chars = get_char(seed_text)
print(chars) for i,c in ipairs(chars) do
for i,c in ipairs(chars)'.' do
prev_char = torch.Tensor{vocab[c]} prev_char = torch.Tensor{vocab[c]}
io.write(ivocab[prev_char[1]]) io.write(ivocab[prev_char[1]])
if opt.gpuid >= 0 then prev_char = prev_char:cuda() end if opt.gpuid >= 0 then prev_char = prev_char:cuda() end
@ -139,17 +140,21 @@ end
for i=1, opt.length do for i=1, opt.length do
-- log probabilities from the previous timestep -- log probabilities from the previous timestep
if opt.sample == 0 then -- make sure the output char is not UNKNOW
-- use argmax prev_char = 'UNKNOW'
local _, prev_char_ = prediction:max(2) while(prev_char == 'UNKNOW') do
prev_char = prev_char_:resize(1) if opt.sample == 0 then
else -- use argmax
-- use sampling local _, prev_char_ = prediction:max(2)
prediction:div(opt.temperature) -- scale by temperature prev_char = prev_char_:resize(1)
local probs = torch.exp(prediction):squeeze() else
probs:div(torch.sum(probs)) -- renormalize so probs sum to one -- use sampling
prev_char = torch.multinomial(probs:float(), 1):resize(1):float() prediction:div(opt.temperature) -- scale by temperature
end local probs = torch.exp(prediction):squeeze()
probs:div(torch.sum(probs)) -- renormalize so probs sum to one
prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
end
end
-- forward the rnn for next character -- forward the rnn for next character
local lst = protos.rnn:forward{prev_char, unpack(current_state)} local lst = protos.rnn:forward{prev_char, unpack(current_state)}

View File

@ -32,6 +32,7 @@ cmd:text()
cmd:text('Options') cmd:text('Options')
-- data -- data
cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data') cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data')
cmd:option('-min_freq',0,'min frequent of character')
-- model params -- model params
cmd:option('-rnn_size', 128, 'size of LSTM internal state') cmd:option('-rnn_size', 128, 'size of LSTM internal state')
cmd:option('-num_layers', 2, 'number of layers in the LSTM') cmd:option('-num_layers', 2, 'number of layers in the LSTM')
@ -59,7 +60,6 @@ cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be in
-- GPU/CPU -- GPU/CPU
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU') cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)') cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
cmd:option('-min_freq',0,'min frequent of character')
cmd:text() cmd:text()
-- parse input params -- parse input params