fixed UNKNOW in sample.lua

This commit is contained in:
Jeff Zhang 2015-07-15 11:24:21 +08:00
parent e99bb2f368
commit 773a174534
2 changed files with 19 additions and 14 deletions

View File

@ -86,6 +86,7 @@ for L = 1,checkpoint.opt.num_layers do
end
state_size = #current_state
-- parse characters from a string
function get_char(str)
local len = #str
local left = 0
@ -108,6 +109,7 @@ function get_char(str)
left = left + i
unordered[#unordered+1] = tmpString
end
return unordered
end
-- do a few seeded timesteps
@ -116,8 +118,7 @@ if string.len(seed_text) > 0 then
gprint('seeding with ' .. seed_text)
gprint('--------------------------')
local chars = get_char(seed_text)
print(chars)
for i,c in ipairs(chars)'.' do
for i,c in ipairs(chars) do
prev_char = torch.Tensor{vocab[c]}
io.write(ivocab[prev_char[1]])
if opt.gpuid >= 0 then prev_char = prev_char:cuda() end
@ -139,6 +140,9 @@ end
for i=1, opt.length do
-- log probabilities from the previous timestep
-- make sure the output char is not UNKNOW
prev_char = 'UNKNOW'
while(prev_char == 'UNKNOW') do
if opt.sample == 0 then
-- use argmax
local _, prev_char_ = prediction:max(2)
@ -150,6 +154,7 @@ for i=1, opt.length do
probs:div(torch.sum(probs)) -- renormalize so probs sum to one
prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
end
end
-- forward the rnn for next character
local lst = protos.rnn:forward{prev_char, unpack(current_state)}

View File

@ -32,6 +32,7 @@ cmd:text()
cmd:text('Options')
-- data
cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data')
cmd:option('-min_freq',0,'min frequent of character')
-- model params
cmd:option('-rnn_size', 128, 'size of LSTM internal state')
cmd:option('-num_layers', 2, 'number of layers in the LSTM')
@ -59,7 +60,6 @@ cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be in
-- GPU/CPU
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
cmd:option('-min_freq',0,'min frequent of character')
cmd:text()
-- parse input params