fixed UNKNOW in sample.lua
This commit is contained in:
parent
e99bb2f368
commit
773a174534
31
sample.lua
31
sample.lua
@ -86,6 +86,7 @@ for L = 1,checkpoint.opt.num_layers do
|
|||||||
end
|
end
|
||||||
state_size = #current_state
|
state_size = #current_state
|
||||||
|
|
||||||
|
-- parse characters from a string
|
||||||
function get_char(str)
|
function get_char(str)
|
||||||
local len = #str
|
local len = #str
|
||||||
local left = 0
|
local left = 0
|
||||||
@ -108,6 +109,7 @@ function get_char(str)
|
|||||||
left = left + i
|
left = left + i
|
||||||
unordered[#unordered+1] = tmpString
|
unordered[#unordered+1] = tmpString
|
||||||
end
|
end
|
||||||
|
return unordered
|
||||||
end
|
end
|
||||||
|
|
||||||
-- do a few seeded timesteps
|
-- do a few seeded timesteps
|
||||||
@ -116,8 +118,7 @@ if string.len(seed_text) > 0 then
|
|||||||
gprint('seeding with ' .. seed_text)
|
gprint('seeding with ' .. seed_text)
|
||||||
gprint('--------------------------')
|
gprint('--------------------------')
|
||||||
local chars = get_char(seed_text)
|
local chars = get_char(seed_text)
|
||||||
print(chars)
|
for i,c in ipairs(chars) do
|
||||||
for i,c in ipairs(chars)'.' do
|
|
||||||
prev_char = torch.Tensor{vocab[c]}
|
prev_char = torch.Tensor{vocab[c]}
|
||||||
io.write(ivocab[prev_char[1]])
|
io.write(ivocab[prev_char[1]])
|
||||||
if opt.gpuid >= 0 then prev_char = prev_char:cuda() end
|
if opt.gpuid >= 0 then prev_char = prev_char:cuda() end
|
||||||
@ -139,17 +140,21 @@ end
|
|||||||
for i=1, opt.length do
|
for i=1, opt.length do
|
||||||
|
|
||||||
-- log probabilities from the previous timestep
|
-- log probabilities from the previous timestep
|
||||||
if opt.sample == 0 then
|
-- make sure the output char is not UNKNOW
|
||||||
-- use argmax
|
prev_char = 'UNKNOW'
|
||||||
local _, prev_char_ = prediction:max(2)
|
while(prev_char == 'UNKNOW') do
|
||||||
prev_char = prev_char_:resize(1)
|
if opt.sample == 0 then
|
||||||
else
|
-- use argmax
|
||||||
-- use sampling
|
local _, prev_char_ = prediction:max(2)
|
||||||
prediction:div(opt.temperature) -- scale by temperature
|
prev_char = prev_char_:resize(1)
|
||||||
local probs = torch.exp(prediction):squeeze()
|
else
|
||||||
probs:div(torch.sum(probs)) -- renormalize so probs sum to one
|
-- use sampling
|
||||||
prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
|
prediction:div(opt.temperature) -- scale by temperature
|
||||||
end
|
local probs = torch.exp(prediction):squeeze()
|
||||||
|
probs:div(torch.sum(probs)) -- renormalize so probs sum to one
|
||||||
|
prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
-- forward the rnn for next character
|
-- forward the rnn for next character
|
||||||
local lst = protos.rnn:forward{prev_char, unpack(current_state)}
|
local lst = protos.rnn:forward{prev_char, unpack(current_state)}
|
||||||
|
@ -32,6 +32,7 @@ cmd:text()
|
|||||||
cmd:text('Options')
|
cmd:text('Options')
|
||||||
-- data
|
-- data
|
||||||
cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data')
|
cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain the file input.txt with input data')
|
||||||
|
cmd:option('-min_freq',0,'min frequent of character')
|
||||||
-- model params
|
-- model params
|
||||||
cmd:option('-rnn_size', 128, 'size of LSTM internal state')
|
cmd:option('-rnn_size', 128, 'size of LSTM internal state')
|
||||||
cmd:option('-num_layers', 2, 'number of layers in the LSTM')
|
cmd:option('-num_layers', 2, 'number of layers in the LSTM')
|
||||||
@ -59,7 +60,6 @@ cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be in
|
|||||||
-- GPU/CPU
|
-- GPU/CPU
|
||||||
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
|
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
|
||||||
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
|
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
|
||||||
cmd:option('-min_freq',0,'min frequent of character')
|
|
||||||
cmd:text()
|
cmd:text()
|
||||||
|
|
||||||
-- parse input params
|
-- parse input params
|
||||||
|
Loading…
Reference in New Issue
Block a user