fixed UNKNOW in sample.lua
This commit is contained in:
parent e99bb2f368
commit 773a174534
sample.lua (31 changed lines)
@@ -86,6 +86,7 @@ for L = 1,checkpoint.opt.num_layers do
 end
 state_size = #current_state
 
 -- parse characters from a string
 function get_char(str)
     local len = #str
     local left = 0
@@ -108,6 +109,7 @@ function get_char(str)
         left = left + i
         unordered[#unordered+1] = tmpString
     end
     return unordered
 end
 
 -- do a few seeded timesteps
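Note: only the tail of get_char is visible in the hunks above. A minimal sketch of a UTF-8-aware splitter in the same spirit is given below; the name split_utf8 and the exact byte handling are illustrative assumptions, not the repository's code. It walks the byte string, reads each character's width from its leading byte, and collects the characters into a table, which is the shape of result the seeding code further down expects.

-- Sketch only: not the repository's get_char implementation.
-- Splits a UTF-8 string into individual characters.
local function split_utf8(str)
    local chars = {}
    local i = 1
    while i <= #str do
        local byte = string.byte(str, i)
        local width = 1                     -- single ASCII byte
        if byte >= 0xF0 then width = 4      -- 4-byte sequence
        elseif byte >= 0xE0 then width = 3  -- 3-byte sequence
        elseif byte >= 0xC0 then width = 2  -- 2-byte sequence
        end
        chars[#chars + 1] = string.sub(str, i, i + width - 1)
        i = i + width
    end
    return chars
end
-- e.g. split_utf8('ab你好') -> {'a', 'b', '你', '好'}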
@@ -116,8 +118,7 @@ if string.len(seed_text) > 0 then
     gprint('seeding with ' .. seed_text)
     gprint('--------------------------')
     local chars = get_char(seed_text)
-    print(chars)
-    for i,c in ipairs(chars)'.' do
+    for i,c in ipairs(chars) do
         prev_char = torch.Tensor{vocab[c]}
         io.write(ivocab[prev_char[1]])
         if opt.gpuid >= 0 then prev_char = prev_char:cuda() end
@@ -139,17 +140,21 @@ end
 for i=1, opt.length do
 
     -- log probabilities from the previous timestep
-    if opt.sample == 0 then
-        -- use argmax
-        local _, prev_char_ = prediction:max(2)
-        prev_char = prev_char_:resize(1)
-    else
-        -- use sampling
-        prediction:div(opt.temperature) -- scale by temperature
-        local probs = torch.exp(prediction):squeeze()
-        probs:div(torch.sum(probs)) -- renormalize so probs sum to one
-        prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
-    end
+    -- make sure the output char is not UNKNOW
+    prev_char = 'UNKNOW'
+    while(prev_char == 'UNKNOW') do
+        if opt.sample == 0 then
+            -- use argmax
+            local _, prev_char_ = prediction:max(2)
+            prev_char = prev_char_:resize(1)
+        else
+            -- use sampling
+            prediction:div(opt.temperature) -- scale by temperature
+            local probs = torch.exp(prediction):squeeze()
+            probs:div(torch.sum(probs)) -- renormalize so probs sum to one
+            prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
+        end
+    end
 
     -- forward the rnn for next character
     local lst = protos.rnn:forward{prev_char, unpack(current_state)}
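The hunk above wraps the argmax/temperature-sampling step in a loop so that generation is retried whenever the UNKNOW placeholder comes up; note that prev_char starts out as the string 'UNKNOW' but is then overwritten with a one-element tensor holding the sampled index. A sketch of the same resample-until-valid idea that compares against the vocabulary id of the UNKNOW token instead is shown below; vocab, opt, prediction and 'UNKNOW' follow the diff, everything else is illustrative rather than the repository's code.

-- Sketch only: retry temperature sampling while the UNKNOW id is drawn.
local unknow_id = vocab['UNKNOW']
if opt.sample == 0 then
    -- argmax over the log-probabilities (deterministic, so no retry loop)
    local _, prev_char_ = prediction:max(2)
    prev_char = prev_char_:resize(1)
else
    -- temperature-scaled sampling, repeated while the UNKNOW id comes up
    prediction:div(opt.temperature)
    local probs = torch.exp(prediction):squeeze()
    probs:div(torch.sum(probs))
    repeat
        prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
    until unknow_id == nil or prev_char[1] ~= unknow_id
end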
@@ -32,6 +32,7 @@ cmd:text()
 cmd:text('Options')
 -- data
 cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data')
+cmd:option('-min_freq',0,'min frequent of character')
 -- model params
 cmd:option('-rnn_size', 128, 'size of LSTM internal state')
 cmd:option('-num_layers', 2, 'number of layers in the LSTM')
@@ -59,7 +60,6 @@ cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be in
 -- GPU/CPU
 cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
 cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
-cmd:option('-min_freq',0,'min frequent of character')
 cmd:text()
 
 -- parse input params
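These two hunks move the -min_freq option from the GPU/CPU block up to the data block. Judging by its description, it presumably sets the frequency threshold below which characters are replaced by the UNKNOW token when the vocabulary is built. A sketch of such a vocabulary builder is shown below; build_vocab and counts are illustrative names, while vocab, ivocab, 'UNKNOW' and min_freq follow names used elsewhere in the diff. This is not the repository's loader code.

-- Sketch only: collapse characters rarer than min_freq into UNKNOW.
local function build_vocab(chars, min_freq)
    local counts = {}
    for _, c in ipairs(chars) do
        counts[c] = (counts[c] or 0) + 1
    end
    local vocab, ivocab = {}, {}
    local function add(c)
        if not vocab[c] then
            ivocab[#ivocab + 1] = c
            vocab[c] = #ivocab
        end
    end
    add('UNKNOW')                      -- reserve an id for rare characters
    for c, n in pairs(counts) do
        if n >= min_freq then add(c) end
    end
    return vocab, ivocab
end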