-- Modified from https://github.com/oxford-cs-ml-2015/practical6
-- the modification included support for train/val/test splits

local CharSplitLMMinibatchLoader = {}
CharSplitLMMinibatchLoader.__index = CharSplitLMMinibatchLoader
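
-- Illustrative usage sketch (not part of the original file; the require path,
-- data directory and hyperparameters below are assumptions). The loader relies
-- on the globals `torch`, `lfs` (luafilesystem) and `path` (penlight's pl.path),
-- which the calling script / Torch environment is expected to provide:
--
--   local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader'
--   -- 95% train, 5% val, no test split; drop characters seen fewer than 2 times
--   local loader = CharSplitLMMinibatchLoader.create('data/my_corpus', 50, 50, {0.95, 0.05, 0}, 2)
--   local x, y = loader:next_batch(1)  -- split 1 = train; x and y are batch_size x seq_length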

function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, split_fractions, min_freq)
    -- split_fractions is e.g. {0.9, 0.05, 0.05}

    local self = {}
    setmetatable(self, CharSplitLMMinibatchLoader)

    local input_file = path.join(data_dir, 'input.txt')
    local vocab_file = path.join(data_dir, 'vocab.t7')
    local tensor_file = path.join(data_dir, 'data.t7')

    -- fetch file attributes to determine if we need to rerun preprocessing
    local run_prepro = false
    if not (path.exists(vocab_file) or path.exists(tensor_file)) then
        -- prepro files do not exist, generate them
        print('vocab.t7 and data.t7 do not exist. Running preprocessing...')
        run_prepro = true
    else
        -- check if the input file was modified since last time we
        -- ran the prepro. if so, we have to rerun the preprocessing
        local input_attr = lfs.attributes(input_file)
        local vocab_attr = lfs.attributes(vocab_file)
        local tensor_attr = lfs.attributes(tensor_file)
        if input_attr.modification > vocab_attr.modification or input_attr.modification > tensor_attr.modification then
            print('vocab.t7 or data.t7 detected as stale. Re-running preprocessing...')
            run_prepro = true
        end
    end
    if run_prepro then
        -- construct a tensor with all the data, and vocab file
        print('one-time setup: preprocessing input text file ' .. input_file .. '...')
        CharSplitLMMinibatchLoader.text_to_tensor(input_file, vocab_file, tensor_file, min_freq)
    end

    print('loading data files...')
    local data = torch.load(tensor_file)
    self.vocab_mapping = torch.load(vocab_file)

    -- cut off the end so that it divides evenly
    local len = data:size(1)
    if len % (batch_size * seq_length) ~= 0 then
        print('cutting off end of data so that the batches/sequences divide evenly')
        data = data:sub(1, batch_size * seq_length
                    * math.floor(len / (batch_size * seq_length)))
    end

    -- count vocab
    self.vocab_size = 0
    for _ in pairs(self.vocab_mapping) do
        self.vocab_size = self.vocab_size + 1
    end

    -- self.batches is a table of tensors
    print('reshaping tensor...')
    self.batch_size = batch_size
    self.seq_length = seq_length

    local ydata = data:clone()
    ydata:sub(1,-2):copy(data:sub(2,-1))
    ydata[-1] = data[1]
    self.x_batches = data:view(batch_size, -1):split(seq_length, 2)  -- #rows = #batches
    self.nbatches = #self.x_batches
    self.y_batches = ydata:view(batch_size, -1):split(seq_length, 2)  -- #rows = #batches
    assert(#self.x_batches == #self.y_batches)
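
    -- Worked example (illustrative numbers, not from the original code): with
    -- 1,000,000 tokens, batch_size = 50 and seq_length = 50, the data is viewed
    -- as a 50 x 20000 matrix and split along dim 2 into nbatches = 400 chunks,
    -- each of size 50 x 50. y_batches has the same layout shifted one token to
    -- the left, with the very first token wrapped around to the last slot, so
    -- y[t] is the prediction target for x[t].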

    -- let's try to be helpful here
    if self.nbatches < 50 then
        print('WARNING: fewer than 50 batches in the data in total. This looks like a very small dataset; you probably want to use a smaller batch_size and/or seq_length.')
    end

    -- perform safety checks on split_fractions
    assert(split_fractions[1] >= 0 and split_fractions[1] <= 1, 'bad split fraction ' .. split_fractions[1] .. ' for train, not between 0 and 1')
    assert(split_fractions[2] >= 0 and split_fractions[2] <= 1, 'bad split fraction ' .. split_fractions[2] .. ' for val, not between 0 and 1')
    assert(split_fractions[3] >= 0 and split_fractions[3] <= 1, 'bad split fraction ' .. split_fractions[3] .. ' for test, not between 0 and 1')
    if split_fractions[3] == 0 then
        -- catch a common special case where the user might not want a test set
        self.ntrain = math.floor(self.nbatches * split_fractions[1])
        self.nval = self.nbatches - self.ntrain
        self.ntest = 0
    else
        -- divide data to train/val and allocate rest to test
        self.ntrain = math.floor(self.nbatches * split_fractions[1])
        self.nval = math.floor(self.nbatches * split_fractions[2])
        self.ntest = self.nbatches - self.nval - self.ntrain -- the rest goes to test (to ensure this adds up exactly)
    end

    self.split_sizes = {self.ntrain, self.nval, self.ntest}
    self.batch_ix = {0, 0, 0}

    print(string.format('data load done. Number of data batches in train: %d, val: %d, test: %d', self.ntrain, self.nval, self.ntest))
    collectgarbage()
    return self
end

function CharSplitLMMinibatchLoader:reset_batch_pointer(split_index, batch_index)
    batch_index = batch_index or 0
    self.batch_ix[split_index] = batch_index
end

function CharSplitLMMinibatchLoader:next_batch(split_index)
    if self.split_sizes[split_index] == 0 then
        -- perform a check here to make sure the user isn't screwing something up
        local split_names = {'train', 'val', 'test'}
        print('ERROR. Code requested a batch for split ' .. split_names[split_index] .. ', but this split has no data.')
        os.exit() -- crash violently
    end
    -- split_index is integer: 1 = train, 2 = val, 3 = test
    self.batch_ix[split_index] = self.batch_ix[split_index] + 1
    if self.batch_ix[split_index] > self.split_sizes[split_index] then
        self.batch_ix[split_index] = 1 -- cycle around to beginning
    end
    -- pull out the correct next batch
    local ix = self.batch_ix[split_index]
    if split_index == 2 then ix = ix + self.ntrain end -- offset by train set size
    if split_index == 3 then ix = ix + self.ntrain + self.nval end -- offset by train + val
    return self.x_batches[ix], self.y_batches[ix]
end
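
-- Illustrative sketch of one pass over the training split (hypothetical caller
-- code, not part of this file):
--
--   loader:reset_batch_pointer(1)            -- 1 = train split
--   for i = 1, loader.ntrain do
--       local x, y = loader:next_batch(1)    -- wraps around automatically after ntrain batches
--       -- ... forward/backward pass on x, y ...
--   end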

-- Chinese vocab: walk the raw string as UTF-8 characters, count how often each
-- character occurs, and keep only those seen at least min_freq times
function get_vocab(str, min_freq)
    local len = #str
    local left = 0
    -- UTF-8 lead-byte thresholds: the largest index i with byte >= arr[i]
    -- gives the byte length of the current character
    local arr = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}
    local unordered = {}
    local start = 1
    local wordLen = 0
    g_total_chars = 0  -- intentionally global: get_data() uses it to size the output tensor
    while len ~= left do
        local tmp = string.byte(str, start)
        local i = #arr
        while arr[i] do
            if tmp >= arr[i] then
                break
            end
            i = i - 1
        end
        wordLen = i + wordLen
        local tmpString = string.sub(str, start, wordLen)
        start = start + i
        left = left + i
        if not unordered[tmpString] then
            unordered[tmpString] = 1
        else
            unordered[tmpString] = unordered[tmpString] + 1
        end
        g_total_chars = g_total_chars + 1
    end
    final_res = {}
    for char_val, char_cnt in pairs(unordered) do
        if char_cnt >= min_freq then
            final_res[char_val] = true
        end
    end
    return final_res
end
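
-- Worked example of the lead-byte table above (illustrative, not from the
-- original file): for an ASCII byte such as 'a' (0x61) the inner loop stops at
-- arr[1] = 0, so the character is 1 byte long; for the first byte of a Chinese
-- character such as '中' (UTF-8: 0xE4 0xB8 0xAD) it stops at arr[3] = 0xe0,
-- so three bytes are consumed as a single vocabulary symbol.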

-- change raw data to tokens
function get_data(str, vocab_mapping)
    -- cannot use torch.ByteTensor because it supports no more than 256 distinct values
    -- local data = torch.ByteTensor(g_total_chars) -- store it into 1D first, then rearrange
    local data = torch.ShortTensor(g_total_chars)
    local len = #str
    local left = 0
    local arr = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}
    local start = 1
    local wordLen = 0
    local count = 1
    while len ~= left do
        local tmp = string.byte(str, start)
        local i = #arr
        while arr[i] do
            if tmp >= arr[i] then
                break
            end
            i = i - 1
        end
        wordLen = i + wordLen
        local tmpString = string.sub(str, start, wordLen)
        start = start + i
        left = left + i
        if vocab_mapping[tmpString] then
            data[count] = vocab_mapping[tmpString]
        else
            -- characters below min_freq were dropped from the vocab; map them to the 'UNKNOW' token
            data[count] = vocab_mapping['UNKNOW']
        end
        count = count + 1
    end
    return data
end

-- *** STATIC method ***
function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile, min_freq)
    local timer = torch.Timer()

    print('loading text file...')
    local f = torch.DiskFile(in_textfile)
    local rawdata = f:readString('*a') -- NOTE: this reads the whole file at once
    f:close()

    -- create vocabulary if it doesn't exist yet
    print('creating vocabulary mapping...')
    -- record all characters to a set
    local unordered = get_vocab(rawdata, min_freq)
    -- sort into a table (i.e. keys become 1..N)
    local ordered = {}
    for char in pairs(unordered) do ordered[#ordered + 1] = char end
    table.sort(ordered)
    -- invert `ordered` to create the char->int mapping
    local vocab_mapping = {}

    count_vocab = 0
    for i, char in ipairs(ordered) do
        vocab_mapping[char] = i
        count_vocab = count_vocab + 1
    end
    -- reserve one extra id for the out-of-vocabulary token ('UNKNOW', as used by get_data)
    vocab_mapping['UNKNOW'] = count_vocab + 1

    -- construct a tensor with all the data
    print('putting data into tensor, this can take a while...')
    local data = get_data(rawdata, vocab_mapping)

    -- save output preprocessed files
    print('saving ' .. out_vocabfile)
    torch.save(out_vocabfile, vocab_mapping)
    print('saving ' .. out_tensorfile)
    torch.save(out_tensorfile, data)
end
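
-- Sketch of the saved artifacts (illustrative keys, assuming the behaviour coded
-- above): vocab.t7 holds a char -> integer table such as
-- { ['a'] = 1, ['b'] = 2, ..., ['UNKNOW'] = N + 1 }, where N is the number of
-- characters with frequency >= min_freq, and data.t7 holds a 1-D ShortTensor of
-- g_total_chars token ids that create() later reshapes into minibatches.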

return CharSplitLMMinibatchLoader