fix bug & add stop option
This commit is contained in:
parent
773a174534
commit
2d3d841f8f
29
sample.lua
29
sample.lua
@ -28,7 +28,8 @@ cmd:argument('-model','model checkpoint to use for sampling')
|
||||
cmd:option('-seed',123,'random number generator\'s seed')
|
||||
cmd:option('-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep')
|
||||
cmd:option('-primetext',"",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.')
|
||||
cmd:option('-length',2000,'number of characters to sample')
|
||||
cmd:option('-length',2000,'max number of characters to sample')
|
||||
cmd:option('-stop','\\n\\n\\n\\n\\n','stop sampling when detected')
|
||||
cmd:option('-temperature',1,'temperature of sampling')
|
||||
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
|
||||
cmd:option('-verbose',1,'set to 0 to ONLY print the sampled text, no diagnostics')
|
||||
@ -137,22 +138,24 @@ else
|
||||
end
|
||||
|
||||
-- start sampling/argmaxing
|
||||
result = ''
|
||||
for i=1, opt.length do
|
||||
|
||||
-- log probabilities from the previous timestep
|
||||
-- make sure the output char is not UNKNOW
|
||||
prev_char = 'UNKNOW'
|
||||
while(prev_char == 'UNKNOW') do
|
||||
if opt.sample == 0 then
|
||||
-- use argmax
|
||||
local _, prev_char_ = prediction:max(2)
|
||||
prev_char = prev_char_:resize(1)
|
||||
else
|
||||
-- use sampling
|
||||
if opt.sample == 0 then
|
||||
-- use argmax
|
||||
local _, prev_char_ = prediction:max(2)
|
||||
prev_char = prev_char_:resize(1)
|
||||
else
|
||||
-- use sampling
|
||||
real_char = 'UNKNOW'
|
||||
while(real_char == 'UNKNOW') do
|
||||
prediction:div(opt.temperature) -- scale by temperature
|
||||
local probs = torch.exp(prediction):squeeze()
|
||||
probs:div(torch.sum(probs)) -- renormalize so probs sum to one
|
||||
prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
|
||||
real_char = ivocab[prev_char[1]]
|
||||
end
|
||||
end
|
||||
|
||||
@ -162,7 +165,13 @@ for i=1, opt.length do
|
||||
for i=1,state_size do table.insert(current_state, lst[i]) end
|
||||
prediction = lst[#lst] -- last element holds the log probabilities
|
||||
|
||||
io.write(ivocab[prev_char[1]])
|
||||
-- io.write(ivocab[prev_char[1]])
|
||||
result = result .. ivocab[prev_char[1]]
|
||||
|
||||
-- in my data, five \n represent the end of each document
|
||||
-- so count \n to stop sampling
|
||||
if string.find(result, opt.stop) then break end
|
||||
end
|
||||
io.write(result)
|
||||
io.write('\n') io.flush()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user