fix UNKNOW resampling bug & add -stop option

Jeff Zhang 2015-07-15 12:15:45 +08:00
parent 773a174534
commit 2d3d841f8f

sample.lua

@@ -28,7 +28,8 @@ cmd:argument('-model','model checkpoint to use for sampling')
 cmd:option('-seed',123,'random number generator\'s seed')
 cmd:option('-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep')
 cmd:option('-primetext',"",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.')
-cmd:option('-length',2000,'number of characters to sample')
+cmd:option('-length',2000,'max number of characters to sample')
+cmd:option('-stop','\\n\\n\\n\\n\\n','stop sampling when detected')
 cmd:option('-temperature',1,'temperature of sampling')
 cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
 cmd:option('-verbose',1,'set to 0 to ONLY print the sampled text, no diagnostics')
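
The new -stop option turns generation into sample-until-marker: characters accumulate in a buffer, and the loop breaks as soon as the marker appears in it. Below is a minimal sketch of the idea in plain Lua 5.1 (the buffer contents and names are illustrative, not the script's actual state). One caveat worth noting: in Lua source, the default '\\n\\n\\n\\n\\n' denotes the literal ten-character text \n\n\n\n\n, so stopping on five real newlines requires passing real newline characters on the command line.

-- sketch of the -stop idea; here 'stop' holds five real newlines
local stop = '\n\n\n\n\n'
local result = ''
for ch in ('ab\n\n\n\n\ncd'):gmatch('.') do  -- Lua's '.' also matches newlines
    result = result .. ch
    -- a plain find (4th argument true) treats the marker as literal text;
    -- the diff itself calls string.find in its default Lua-pattern mode
    if string.find(result, stop, 1, true) then break end
end
print(#result)  -- 7: "ab" plus the five newlines; "cd" is never appended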
@@ -137,22 +138,24 @@ else
 end
 -- start sampling/argmaxing
+result = ''
 for i=1, opt.length do
     -- log probabilities from the previous timestep
-    -- make sure the output char is not UNKNOW
-    prev_char = 'UNKNOW'
-    while(prev_char == 'UNKNOW') do
-        if opt.sample == 0 then
-            -- use argmax
-            local _, prev_char_ = prediction:max(2)
-            prev_char = prev_char_:resize(1)
-        else
-            -- use sampling
+    if opt.sample == 0 then
+        -- use argmax
+        local _, prev_char_ = prediction:max(2)
+        prev_char = prev_char_:resize(1)
+    else
+        -- use sampling
+        real_char = 'UNKNOW'
+        while(real_char == 'UNKNOW') do
             prediction:div(opt.temperature) -- scale by temperature
             local probs = torch.exp(prediction):squeeze()
             probs:div(torch.sum(probs)) -- renormalize so probs sum to one
             prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
+            real_char = ivocab[prev_char[1]]
         end
     end
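
The bug being fixed: the old loop assigns a 1-element tensor to prev_char on the first pass, so while(prev_char == 'UNKNOW') then compares a tensor against a string, evaluates false, and exits; the UNKNOW filter never actually retried. It also wrapped the argmax branch, which (had the check worked) would spin forever whenever the most likely character was UNKNOW. The fix retries only in the sampling branch and tests the decoded character. Here is a sketch of the fixed loop as a standalone helper (a hypothetical function, not part of sample.lua; assumes Torch7, a 1xV log-probability tensor, and the script's ivocab index-to-character table):

require 'torch'

-- resample until the decoded character is not the UNKNOW token
local function sample_known_char(prediction, ivocab, temperature)
    local prev_char, real_char = nil, 'UNKNOW'
    while real_char == 'UNKNOW' do
        -- div() mutates prediction in place, so every retry applies the
        -- temperature once more, reshaping the distribution slightly
        prediction:div(temperature)
        local probs = torch.exp(prediction):squeeze()
        probs:div(torch.sum(probs))  -- renormalize so probs sum to one
        prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
        real_char = ivocab[prev_char[1]]  -- decode before comparing
    end
    return prev_char, real_char
end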
@@ -162,7 +165,13 @@ for i=1, opt.length do
     for i=1,state_size do table.insert(current_state, lst[i]) end
     prediction = lst[#lst] -- last element holds the log probabilities
-    io.write(ivocab[prev_char[1]])
+    -- io.write(ivocab[prev_char[1]])
+    result = result .. ivocab[prev_char[1]]
+    -- in my data, five \n mark the end of each document,
+    -- so stop sampling once opt.stop is found in the result
+    if string.find(result, opt.stop) then break end
 end
+io.write(result)
 io.write('\n') io.flush()
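
With both hunks applied, a run that stops at a document boundary might look like the line below. The checkpoint path is hypothetical, and bash's $'...' quoting is one way to hand the script five real newlines, since the option's default is the escaped text rather than newline characters:

th sample.lua cv/lm_checkpoint.t7 -length 2000 -stop $'\n\n\n\n\n'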