char-rnn-chinese/sample.lua

199 lines
6.0 KiB
Lua
Raw Permalink Normal View History

--[[
This file samples characters from a trained model
Code is based on implementation in
https://github.com/oxford-cs-ml-2015/practical6
]]--
require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
require 'lfs'
require 'util.OneHot'
require 'util.misc'
-- Command-line interface: one required positional argument (the checkpoint
-- to sample from) followed by optional flags. The flags are registered in
-- the order listed in the table below.
cmd = torch.CmdLine()
cmd:text()
cmd:text('Sample from a character-level language model')
cmd:text()
cmd:text('Options')
-- required:
cmd:argument('-model','model checkpoint to use for sampling')
-- optional parameters, each given as {flag, default value, help text}
local optional_flags = {
    {'-seed',123,'random number generator\'s seed'},
    {'-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep'},
    {'-primetext',"",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.'},
    {'-length',2000,'max number of characters to sample'},
    {'-temperature',1,'temperature of sampling'},
    {'-gpuid',0,'which gpu to use. -1 = use CPU'},
    {'-verbose',1,'set to 0 to ONLY print the sampled text, no diagnostics'},
    {'-stop','\n\n\n\n\n','stop sampling when detected'},
}
for _, flag in ipairs(optional_flags) do
    cmd:option(flag[1], flag[2], flag[3])
end
cmd:text()
-- parse input params; `opt` stays global because the rest of the script
-- (e.g. gprint) reads it as a global
opt = cmd:parse(arg)
2018-07-10 11:32:52 +00:00
-- Override the prime text with the first line of a fixed input file.
-- If the file cannot be opened, the value given with -primetext is kept
-- (a warning goes to stderr so stdout still carries only program output).
do
    local input_path = "D:\\Work\\nodejsprojects\\rnn-chat-backend\\input.txt"
    local input_file, open_err = io.open(input_path, 'r')
    if input_file then
        -- read() with no format returns one line, or nil for an empty file
        local first_line = input_file:read()
        input_file:close()
        -- strip line-break bytes, including a stray '\r' from CRLF files;
        -- the extra parentheses drop gsub's second return value
        opt.primetext = ((first_line or ""):gsub("[\r\n]", ""))
    else
        io.stderr:write('warning: cannot open prime text file: '
            .. tostring(open_err) .. '; keeping -primetext\n')
    end
end
--opt.primetext = "恍恍惚惚"
--- Gated print: forwards `str` to print() only in verbose mode.
-- Reads the global `opt` at call time; nothing is printed unless
-- opt.verbose equals 1.
-- @param str value to print
function gprint(str)
    if opt.verbose ~= 1 then
        return
    end
    print(str)
end
-- GPU setup: when a GPU is requested (gpuid >= 0), try to load cunn and
-- cutorch. If either package is missing, report it and switch to CPU mode
-- by overwriting opt.gpuid with -1.
if opt.gpuid >= 0 then
    local has_cunn = pcall(require, 'cunn')
    local has_cutorch, cutorch_lib = pcall(require, 'cutorch')
    if not has_cunn then gprint('package cunn not found!') end
    if not has_cutorch then gprint('package cutorch not found!') end
    if not (has_cunn and has_cutorch) then
        gprint('Falling back on CPU mode')
        opt.gpuid = -1 -- overwrite user setting
    else
        gprint('using CUDA on GPU ' .. opt.gpuid .. '...')
        -- -gpuid is zero-based on the command line; setDevice is given
        -- gpuid + 1 (presumably because cutorch numbers devices from 1)
        cutorch_lib.setDevice(opt.gpuid + 1)
        cutorch_lib.manualSeed(opt.seed)
    end
end
-- the CPU generator is seeded in every mode
torch.manualSeed(opt.seed)
-- load the model checkpoint
-- Abort right away when the file is missing: the message must be shown even
-- with -verbose 0, and continuing would only fail later inside torch.load
-- with a less helpful error. Level 0 keeps the message free of a
-- file:line prefix.
if not lfs.attributes(opt.model, 'mode') then
    error('Error: File ' .. opt.model .. ' does not exist. Are you sure you didn\'t forget to prepend cv/ ?', 0)
end
-- the checkpoint is read in torch's 'ascii' serialization format
checkpoint = torch.load(opt.model, 'ascii')
protos = checkpoint.protos
protos.rnn:evaluate() -- put in eval mode so that dropout works properly
-- Vocabulary: `vocab` maps character -> index as stored in the checkpoint;
-- `ivocab` is the inverse mapping (index -> character).
local vocab = checkpoint.vocab
local ivocab = {}
for char, index in pairs(vocab) do
    ivocab[index] = char
end

-- Initial recurrent state: two all-zero 1 x rnn_size tensors per layer
-- (c and h for every layer), moved to the GPU when one is in use.
gprint('creating an LSTM...')
local num_layers = checkpoint.opt.num_layers
local current_state = {}
local zero_state = torch.zeros(1, checkpoint.opt.rnn_size):double()
if opt.gpuid >= 0 then zero_state = zero_state:cuda() end
for _ = 1, num_layers do
    current_state[#current_state + 1] = zero_state:clone()
    current_state[#current_state + 1] = zero_state:clone()
end
-- number of state tensors carried from one timestep to the next
state_size = #current_state
2015-07-15 03:24:21 +00:00
--- Split a UTF-8 encoded string into a list of characters.
-- The width of each character is derived from its lead byte:
-- below 0xc0 -> 1 byte, 0xc0.. -> 2, 0xe0.. -> 3, 0xf0.. -> 4,
-- 0xf8.. -> 5, 0xfc.. -> 6. A truncated sequence at the end of the string
-- is returned as a shorter final piece instead of running past the end.
-- The number of bytes consumed is reported through gprint, so it stays
-- silent when diagnostics are disabled.
-- @param str string to split (may be empty)
-- @return array of substrings, one per character, in order
function get_char(str)
    local lead_thresholds = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}
    local chars = {}
    local len = #str
    local pos = 1
    while pos <= len do
        local lead = string.byte(str, pos)
        -- pick the largest width whose threshold the lead byte reaches
        local width = #lead_thresholds
        while width > 1 and lead < lead_thresholds[width] do
            width = width - 1
        end
        -- clamp so an incomplete trailing sequence cannot overshoot
        local last = math.min(pos + width - 1, len)
        chars[#chars + 1] = string.sub(str, pos, last)
        pos = last + 1
    end
    gprint("get_char wordlen" .. (pos - 1))
    return chars
end
-- do a few seeded timesteps
-- Each character of the prime text is fed through the network so that
-- `current_state` and `prediction` reflect the prompt. `prev_char` and
-- `prediction` are globals shared with the sampling loop further down.
local seed_text = opt.primetext
local seeded = false
if string.len(seed_text) > 0 then
    gprint('seeding with ' .. seed_text)
    gprint('--------------------------')
    for _, c in ipairs(get_char(seed_text)) do
        -- A character absent from the vocabulary used to build an empty
        -- tensor and crash. Use the 'UNKNOW' entry instead when the
        -- checkpoint has one (the sampling loop treats that string as the
        -- unknown-character token); otherwise skip the character.
        local index = vocab[c] or vocab['UNKNOW']
        if index == nil then
            gprint('skipping seed character missing from vocabulary: ' .. c)
        else
            prev_char = torch.Tensor{index}
            io.write(ivocab[prev_char[1]])
            if opt.gpuid >= 0 then prev_char = prev_char:cuda() end
            local lst = protos.rnn:forward{prev_char, unpack(current_state)}
            -- lst is a list of [state1,state2,..stateN,output]. We want everything but last piece
            current_state = {}
            for k = 1, state_size do current_state[k] = lst[k] end
            prediction = lst[#lst] -- last element holds the log probabilities
            seeded = true
        end
    end
end
if not seeded then
    -- No character was fed (empty prime text, or nothing usable in it):
    -- start from a constant prediction, i.e. uniform over characters.
    gprint('missing seed text, using uniform probability over first character')
    gprint('--------------------------')
    prediction = torch.Tensor(1, #ivocab):fill(1)/(#ivocab)
    if opt.gpuid >= 0 then prediction = prediction:cuda() end
end
-- start sampling/argmaxing
-- Generates up to opt.length characters into the global `result`. Each step
-- picks the next character from `prediction` (log probabilities from the
-- previous timestep), feeds it back through the network, and stops early
-- once the stop string appears in the output.
result = ''
for _ = 1, opt.length do
    if opt.sample == 0 then
        -- use argmax
        local _, argmax = prediction:max(2)
        prev_char = argmax:resize(1)
    else
        -- use sampling
        -- Temperature scaling and normalization are done once per timestep;
        -- previously they sat inside the retry loop, so every rejected
        -- 'UNKNOW' draw divided by the temperature again.
        prediction:div(opt.temperature) -- scale by temperature
        local probs = torch.exp(prediction):squeeze()
        probs:div(torch.sum(probs)) -- renormalize so probs sum to one
        probs = probs:float()
        -- redraw until the sampled character is not the UNKNOW token
        repeat
            prev_char = torch.multinomial(probs, 1):resize(1):float()
        until ivocab[prev_char[1]] ~= 'UNKNOW'
    end
    -- forward the rnn for next character
    local lst = protos.rnn:forward{prev_char, unpack(current_state)}
    current_state = {}
    for k = 1, state_size do current_state[k] = lst[k] end
    prediction = lst[#lst] -- last element holds the log probabilities
    local sampled = ivocab[prev_char[1]]
    result = result .. sampled
    -- in my data, five \n represent the end of each document
    -- so count \n to stop sampling
    -- The stop string is matched literally (plain find), and only the tail
    -- that can overlap the newly appended character is searched. An empty
    -- stop string disables early stopping.
    if #opt.stop > 0 then
        local search_from = math.max(1, #result - #sampled - #opt.stop + 2)
        if string.find(result, opt.stop, search_from, true) then break end
    end
end
2015-07-15 04:15:45 +00:00
-- Emit the sampled text on stdout and also save it to a fixed file.
io.write(result)
do
    local out_path = 'C:\\Users\\MACHENIKE\\Downloads\\char-rnn-chinese-master\\temp.txt'
    local fd, open_err = io.open(out_path, 'w')
    if fd then
        fd:write(result)
        fd:close()
    else
        -- A missing or unwritable path must not abort the run before the
        -- final newline and flush; report it on stderr instead.
        io.stderr:write('warning: cannot write result file: '
            .. tostring(open_err) .. '\n')
    end
end
io.write('\n') io.flush()