From 2d3d841f8f20c3aad31d00d9555f6521480452f7 Mon Sep 17 00:00:00 2001
From: Jeff Zhang
Date: Wed, 15 Jul 2015 12:15:45 +0800
Subject: [PATCH] fix bug & add stop option

---
 sample.lua | 29 +++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/sample.lua b/sample.lua
index 39c7e68..9dd9024 100644
--- a/sample.lua
+++ b/sample.lua
@@ -28,7 +28,8 @@ cmd:argument('-model','model checkpoint to use for sampling')
 cmd:option('-seed',123,'random number generator\'s seed')
 cmd:option('-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep')
 cmd:option('-primetext',"",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.')
-cmd:option('-length',2000,'number of characters to sample')
+cmd:option('-length',2000,'max number of characters to sample')
+cmd:option('-stop','\\n\\n\\n\\n\\n','stop sampling when detected')
 cmd:option('-temperature',1,'temperature of sampling')
 cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
 cmd:option('-verbose',1,'set to 0 to ONLY print the sampled text, no diagnostics')
@@ -137,22 +138,24 @@ else
 end
 
 -- start sampling/argmaxing
+result = ''
 for i=1, opt.length do
 
     -- log probabilities from the previous timestep
     -- make sure the output char is not UNKNOW
-    prev_char = 'UNKNOW'
-    while(prev_char == 'UNKNOW') do
-        if opt.sample == 0 then
-            -- use argmax
-            local _, prev_char_ = prediction:max(2)
-            prev_char = prev_char_:resize(1)
-        else
-            -- use sampling
+    if opt.sample == 0 then
+        -- use argmax
+        local _, prev_char_ = prediction:max(2)
+        prev_char = prev_char_:resize(1)
+    else
+        -- use sampling
+        real_char = 'UNKNOW'
+        while(real_char == 'UNKNOW') do
             prediction:div(opt.temperature) -- scale by temperature
             local probs = torch.exp(prediction):squeeze()
             probs:div(torch.sum(probs)) -- renormalize so probs sum to one
             prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
+            real_char = ivocab[prev_char[1]]
         end
     end
 
@@ -162,7 +165,13 @@ for i=1, opt.length do
     for i=1,state_size do table.insert(current_state, lst[i]) end
     prediction = lst[#lst] -- last element holds the log probabilities
 
-    io.write(ivocab[prev_char[1]])
+    -- io.write(ivocab[prev_char[1]])
+    result = result .. ivocab[prev_char[1]]
+
+    -- in my data, five \n represent the end of each document
+    -- so count \n to stop sampling
+    if string.find(result, opt.stop) then break end
 end
+io.write(result)
 io.write('\n')
 io.flush()
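
A note on the stop check introduced above: string.find interprets its second argument as a Lua pattern, so opt.stop is matched as a pattern rather than as a plain substring. A minimal standalone sketch of the detection step follows; the buffer contents are hypothetical and assume five literal newlines mark a document boundary, as the patch comment describes.

    -- hypothetical sampled buffer and stop string, for illustration only
    local result = "first document\n\n\n\n\nsecond document"
    local stop   = "\n\n\n\n\n"   -- five real newline characters

    -- string.find(s, pattern, init, plain): passing true as the fourth
    -- argument requests a plain substring search with no pattern magic
    if string.find(result, stop, 1, true) then
        print("stop sequence detected; sampling would break here")
    end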