fix bug & add stop option
parent 773a174534
commit 2d3d841f8f
sample.lua (29 changed lines)
@@ -28,7 +28,8 @@ cmd:argument('-model','model checkpoint to use for sampling')
 cmd:option('-seed',123,'random number generator\'s seed')
 cmd:option('-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep')
 cmd:option('-primetext',"",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.')
-cmd:option('-length',2000,'number of characters to sample')
+cmd:option('-length',2000,'max number of characters to sample')
+cmd:option('-stop','\\n\\n\\n\\n\\n','stop sampling when detected')
 cmd:option('-temperature',1,'temperature of sampling')
 cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
 cmd:option('-verbose',1,'set to 0 to ONLY print the sampled text, no diagnostics')
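For reference, a minimal standalone sketch of how the new option is read, using torch's stock CmdLine parser; the parsed values below are hypothetical, only the option names and defaults come from the hunk above.

require 'torch'

local cmd = torch.CmdLine()
cmd:option('-length', 2000, 'max number of characters to sample')
cmd:option('-stop', '\\n\\n\\n\\n\\n', 'stop sampling when detected')

-- hypothetical command line: th sample.lua <checkpoint> -length 500 -stop '====='
local opt = cmd:parse({'-length', '500', '-stop', '====='})
print(opt.length, opt.stop)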
@@ -137,22 +138,24 @@ else
 end
 
 -- start sampling/argmaxing
+result = ''
 for i=1, opt.length do
 
     -- log probabilities from the previous timestep
     -- make sure the output char is not UNKNOW
-    prev_char = 'UNKNOW'
-    while(prev_char == 'UNKNOW') do
-        if opt.sample == 0 then
-            -- use argmax
-            local _, prev_char_ = prediction:max(2)
-            prev_char = prev_char_:resize(1)
-        else
-            -- use sampling
+    if opt.sample == 0 then
+        -- use argmax
+        local _, prev_char_ = prediction:max(2)
+        prev_char = prev_char_:resize(1)
+    else
+        -- use sampling
+        real_char = 'UNKNOW'
+        while(real_char == 'UNKNOW') do
             prediction:div(opt.temperature) -- scale by temperature
             local probs = torch.exp(prediction):squeeze()
             probs:div(torch.sum(probs)) -- renormalize so probs sum to one
             prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
+            real_char = ivocab[prev_char[1]]
         end
     end
 
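What the fix changes: in the old code, prev_char is reassigned to a tensor inside the retry loop, and a tensor never compares equal to the string 'UNKNOW', so the loop exited after a single pass and could still emit UNKNOW; the new code decodes the sampled index through ivocab into real_char and retries on that, and it drops the retry from the argmax branch, where re-drawing could not change the result anyway. A standalone sketch of the comparison difference (the vocabulary table and index are made up):

require 'torch'

local ivocab = { [1] = 'UNKNOW', [2] = 'a', [3] = 'b' }   -- hypothetical inverse vocabulary

-- old guard: compares the sampled *tensor* against a string, which is always false
local prev_char = torch.LongTensor({1})
print(prev_char == 'UNKNOW')          -- false, although ivocab[1] is 'UNKNOW'

-- fixed guard: decode the index first and compare the resulting string
local real_char = ivocab[prev_char[1]]
print(real_char == 'UNKNOW')          -- true, so the fixed loop keeps re-sampling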
@@ -162,7 +165,13 @@ for i=1, opt.length do
     for i=1,state_size do table.insert(current_state, lst[i]) end
     prediction = lst[#lst] -- last element holds the log probabilities
 
-    io.write(ivocab[prev_char[1]])
+    -- io.write(ivocab[prev_char[1]])
+    result = result .. ivocab[prev_char[1]]
+
+    -- in my data, five \n represent the end of each document
+    -- so count \n to stop sampling
+    if string.find(result, opt.stop) then break end
 end
+io.write(result)
 io.write('\n') io.flush()
 
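A note on the stop check: string.find(result, opt.stop) interprets the stop string as a Lua pattern, so magic characters such as % or - in a custom marker would need escaping, and the default '\\n\\n\\n\\n\\n' is a literal backslash-n sequence in Lua source rather than five newline characters. A standalone sketch of how the default matches (the sample strings are made up):

local stop = '\\n\\n\\n\\n\\n'             -- the -stop default: "\n" spelled out five times (ten characters)

local escaped = 'first document' .. string.rep('\\n', 5) .. 'second document'
print(string.find(escaped, stop) ~= nil)   -- true: the pattern matches the literal backslash-n run

local newlines = 'first document' .. string.rep('\n', 5)
print(string.find(newlines, stop) ~= nil)  -- false: five real newlines do not match the default; pass the marker your data actually contains via -stop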