diff --git a/sample.lua b/sample.lua index 20ac4da..39c7e68 100644 --- a/sample.lua +++ b/sample.lua @@ -86,6 +86,7 @@ for L = 1,checkpoint.opt.num_layers do end state_size = #current_state +-- parse characters from a string function get_char(str) local len = #str local left = 0 @@ -108,6 +109,7 @@ function get_char(str) left = left + i unordered[#unordered+1] = tmpString end + return unordered end -- do a few seeded timesteps @@ -116,8 +118,7 @@ if string.len(seed_text) > 0 then gprint('seeding with ' .. seed_text) gprint('--------------------------') local chars = get_char(seed_text) - print(chars) - for i,c in ipairs(chars)'.' do + for i,c in ipairs(chars) do prev_char = torch.Tensor{vocab[c]} io.write(ivocab[prev_char[1]]) if opt.gpuid >= 0 then prev_char = prev_char:cuda() end @@ -139,17 +140,21 @@ end for i=1, opt.length do -- log probabilities from the previous timestep - if opt.sample == 0 then - -- use argmax - local _, prev_char_ = prediction:max(2) - prev_char = prev_char_:resize(1) - else - -- use sampling - prediction:div(opt.temperature) -- scale by temperature - local probs = torch.exp(prediction):squeeze() - probs:div(torch.sum(probs)) -- renormalize so probs sum to one - prev_char = torch.multinomial(probs:float(), 1):resize(1):float() - end + -- make sure the output char is not UNKNOW + prev_char = 'UNKNOW' + while(prev_char == 'UNKNOW') do + if opt.sample == 0 then + -- use argmax + local _, prev_char_ = prediction:max(2) + prev_char = prev_char_:resize(1) + else + -- use sampling + prediction:div(opt.temperature) -- scale by temperature + local probs = torch.exp(prediction):squeeze() + probs:div(torch.sum(probs)) -- renormalize so probs sum to one + prev_char = torch.multinomial(probs:float(), 1):resize(1):float() + end + end -- forward the rnn for next character local lst = protos.rnn:forward{prev_char, unpack(current_state)} diff --git a/train.lua b/train.lua index 0e332be..abe3680 100644 --- a/train.lua +++ b/train.lua @@ -32,6 +32,7 @@ cmd:text() cmd:text('Options') -- data cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data') +cmd:option('-min_freq',0,'min frequent of character') -- model params cmd:option('-rnn_size', 128, 'size of LSTM internal state') cmd:option('-num_layers', 2, 'number of layers in the LSTM') @@ -59,7 +60,6 @@ cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be in -- GPU/CPU cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU') cmd:option('-opencl',0,'use OpenCL (instead of CUDA)') -cmd:option('-min_freq',0,'min frequent of character') cmd:text() -- parse input params