2015-07-07 04:09:34 +00:00
-- Modified from https://github.com/oxford-cs-ml-2015/practical6
-- The modifications add support for train/val/test splits.
-- Class table for the loader. Instances built by create() get it as their
-- metatable, and pointing __index back at the table itself makes the methods
-- defined on it (next_batch, reset_batch_pointer, ...) visible on instances.
local CharSplitLMMinibatchLoader = { }
CharSplitLMMinibatchLoader.__index = CharSplitLMMinibatchLoader
--- Build a minibatch loader over the files in `data_dir`.
-- Runs preprocessing (raw text -> vocab file + data tensor file) when either
-- cached file is missing or older than the input text, loads the cached files,
-- trims the data so it divides evenly into batch_size x seq_length chunks, and
-- assigns the resulting batches to train/val/test.
-- @param data_dir directory holding the raw input text and the cached files
-- @param batch_size number of rows per minibatch
-- @param seq_length number of tokens per row
-- @param split_fractions e.g. {0.9, 0.05, 0.05}: fraction of batches for train/val/test
-- @param min_freq forwarded to text_to_tensor; only has an effect when
--   preprocessing actually runs (cached files are reused otherwise)
-- @return the loader instance
function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, split_fractions, min_freq)
    -- split_fractions is e.g. {0.9, 0.05, 0.05}
    local self = {}
    setmetatable(self, CharSplitLMMinibatchLoader)

    local input_file = path.join(data_dir, ' input.txt ')
    local vocab_file = path.join(data_dir, ' vocab.t7 ')
    local tensor_file = path.join(data_dir, ' data.t7 ')

    -- fetch file attributes to determine if we need to rerun preprocessing
    local run_prepro = false
    -- Both cached files are required. Checking `not (a or b)` would send the
    -- "only one file exists" case into the staleness check below, where
    -- lfs.attributes of the missing file returns nil and indexing it crashes.
    if not (path.exists(vocab_file) and path.exists(tensor_file)) then
        -- prepro files do not exist, generate them
        print(' vocab.t7 and/or data.t7 do not exist. Running preprocessing... ')
        run_prepro = true
    else
        -- check if the input file was modified since last time we
        -- ran the prepro. if so, we have to rerun the preprocessing
        local input_attr = lfs.attributes(input_file)
        local vocab_attr = lfs.attributes(vocab_file)
        local tensor_attr = lfs.attributes(tensor_file)
        -- if the raw input is gone, fall back to the cached files instead of
        -- crashing on a nil attribute table
        if input_attr and (input_attr.modification > vocab_attr.modification
                or input_attr.modification > tensor_attr.modification) then
            print(' vocab.t7 or data.t7 detected as stale. Re-running preprocessing... ')
            run_prepro = true
        end
    end

    if run_prepro then
        -- construct a tensor with all the data, and vocab file
        print(' one-time setup: preprocessing input text file ' .. input_file .. ' ... ')
        CharSplitLMMinibatchLoader.text_to_tensor(input_file, vocab_file, tensor_file, min_freq)
    end

    print(' loading data files... ')
    local data = torch.load(tensor_file, ' ascii ')
    self.vocab_mapping = torch.load(vocab_file, ' ascii ')

    -- cut off the end so that it divides evenly
    local chunk = batch_size * seq_length
    local len = data:size(1)
    assert(len >= chunk, string.format(
        'data has only %d tokens, fewer than batch_size * seq_length = %d', len, chunk))
    if len % chunk ~= 0 then
        print(' cutting off end of data so that the batches/sequences divide evenly ')
        data = data:sub(1, chunk * math.floor(len / chunk))
    end

    -- count vocab (pairs, because the mapping is keyed by characters)
    self.vocab_size = 0
    for _ in pairs(self.vocab_mapping) do
        self.vocab_size = self.vocab_size + 1
    end

    -- self.batches is a table of tensors
    print(' reshaping tensor... ')
    self.batch_size = batch_size
    self.seq_length = seq_length

    -- targets are the inputs shifted left by one; the last target wraps to data[1]
    local ydata = data:clone()
    ydata:sub(1, -2):copy(data:sub(2, -1))
    ydata[-1] = data[1]
    self.x_batches = data:view(batch_size, -1):split(seq_length, 2) -- #rows = #batches
    self.nbatches = #self.x_batches
    self.y_batches = ydata:view(batch_size, -1):split(seq_length, 2) -- #rows = #batches
    assert(#self.x_batches == #self.y_batches)

    -- lets try to be helpful here
    if self.nbatches < 50 then
        print(' WARNING: less than 50 batches in the data in total? Looks like very small dataset. You probably want to use smaller batch_size and/or seq_length. ')
    end

    -- perform safety checks on split_fractions
    assert(split_fractions[1] >= 0 and split_fractions[1] <= 1, ' bad split fraction ' .. split_fractions[1] .. ' for train, not between 0 and 1 ')
    assert(split_fractions[2] >= 0 and split_fractions[2] <= 1, ' bad split fraction ' .. split_fractions[2] .. ' for val, not between 0 and 1 ')
    assert(split_fractions[3] >= 0 and split_fractions[3] <= 1, ' bad split fraction ' .. split_fractions[3] .. ' for test, not between 0 and 1 ')
    if split_fractions[3] == 0 then
        -- catch a common special case where the user might not want a test set
        self.ntrain = math.floor(self.nbatches * split_fractions[1])
        self.nval = self.nbatches - self.ntrain
        self.ntest = 0
    else
        -- divide data to train/val and allocate rest to test
        self.ntrain = math.floor(self.nbatches * split_fractions[1])
        self.nval = math.floor(self.nbatches * split_fractions[2])
        self.ntest = self.nbatches - self.nval - self.ntrain -- the rest goes to test (to ensure this adds up exactly)
        -- a negative test size would make next_batch(3) index past the batch list
        assert(self.ntest >= 0, 'bad split fractions: train + val fractions add up to more than 1')
    end

    self.split_sizes = {self.ntrain, self.nval, self.ntest}
    self.batch_ix = {0, 0, 0}

    print(string.format(' data load done. Number of data batches in train: %d, val: %d, test: %d ', self.ntrain, self.nval, self.ntest))
    collectgarbage()
    return self
end
--- Rewind the batch cursor of one split.
-- @param split_index 1 = train, 2 = val, 3 = test
-- @param batch_index optional cursor position; when omitted the cursor goes to 0,
--   so the following next_batch call returns the first batch of that split
function CharSplitLMMinibatchLoader:reset_batch_pointer(split_index, batch_index)
    local cursor = 0
    if batch_index then
        cursor = batch_index
    end
    self.batch_ix[split_index] = cursor
end
--- Return the next (x, y) minibatch of one split and advance that split's
-- cursor, wrapping back to the split's first batch after its last one.
-- Terminates the process (non-zero status) if the requested split is empty.
-- @param split_index integer: 1 = train, 2 = val, 3 = test
-- @return x batch, y batch
function CharSplitLMMinibatchLoader:next_batch(split_index)
    local split_size = self.split_sizes[split_index]
    assert(split_size ~= nil,
        'next_batch: split_index must be 1 (train), 2 (val) or 3 (test), got ' .. tostring(split_index))
    if split_size == 0 then
        -- perform a check here to make sure the user isn't screwing something up
        local split_names = {' train ', ' val ', ' test '}
        print(' ERROR. Code requested a batch for split ' .. split_names[split_index] .. ' , but this split has no data. ')
        -- crash violently; pass a failure code, since a bare os.exit() reports success
        os.exit(1)
    end
    -- advance the cursor, cycling around to the beginning of the split
    self.batch_ix[split_index] = self.batch_ix[split_index] + 1
    if self.batch_ix[split_index] > split_size then
        self.batch_ix[split_index] = 1
    end
    -- batches are stored train first, then val, then test: offset into the shared list
    local ix = self.batch_ix[split_index]
    if split_index == 2 then ix = ix + self.ntrain end -- offset by train set size
    if split_index == 3 then ix = ix + self.ntrain + self.nval end -- offset by train + val
    return self.x_batches[ix], self.y_batches[ix]
end
-- UTF-8 character vocabulary (used for Chinese text)
--- Count the UTF-8 characters of `str` and return those seen at least `min_freq` times.
-- Side effect: sets the global g_total_chars to the number of characters in
-- `str`; get_data reads it to size its output tensor.
-- @param str raw text
-- @param min_freq minimum occurrence count for a character to be kept (default 1)
-- @return table mapping each kept character (a string) to true
function get_vocab(str, min_freq)
    min_freq = min_freq or 1
    local len = #str
    -- a lead byte >= lead_bytes[k] starts a k-byte UTF-8 sequence
    local lead_bytes = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}
    local counts = {}
    local start = 1
    g_total_chars = 0
    -- `start <= len` also terminates when a truncated multi-byte sequence at the
    -- end of the text would step past len (an exact-match test would never stop
    -- and then call string.byte past the end)
    while start <= len do
        local lead = string.byte(str, start)
        local width = #lead_bytes
        while lead < lead_bytes[width] do
            width = width - 1
        end
        local char = string.sub(str, start, start + width - 1)
        start = start + width
        counts[char] = (counts[char] or 0) + 1
        g_total_chars = g_total_chars + 1
    end
    local final_res = {}
    for char_val, char_cnt in pairs(counts) do
        if char_cnt >= min_freq then
            final_res[char_val] = true
        end
    end
    return final_res
end
-- change raw data to tokens
--- Encode `str` as a 1-D tensor of vocabulary ids, one per UTF-8 character.
-- Characters missing from `vocab_mapping` get the id stored under ' UNKNOW '.
-- The tensor is sized from the global g_total_chars, so get_vocab must have
-- been called on the same `str` first.
-- @param str raw text
-- @param vocab_mapping table mapping character -> integer id (including ' UNKNOW ')
-- @return torch.ShortTensor of ids, or torch.IntTensor when some id exceeds 32767
function get_data(str, vocab_mapping)
    assert(type(g_total_chars) == 'number',
        'get_data: g_total_chars is not set; call get_vocab on the same text first')
    -- torch.ByteTensor cannot be used: it holds no more than 256 distinct values.
    -- ShortTensor keeps the data compact but only holds ids up to 32767, so large
    -- vocabularies fall back to IntTensor instead of storing out-of-range ids.
    -- NOTE(review): confirm downstream consumers of the data file accept IntTensor.
    local max_id = 0
    for _, id in pairs(vocab_mapping) do
        if id > max_id then max_id = id end
    end
    local data
    if max_id > 32767 then
        data = torch.IntTensor(g_total_chars)
    else
        data = torch.ShortTensor(g_total_chars)
    end
    local unknown_id = vocab_mapping[' UNKNOW ']
    local len = #str
    -- a lead byte >= lead_bytes[k] starts a k-byte UTF-8 sequence
    local lead_bytes = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}
    local start = 1
    local count = 1
    -- `start <= len` also terminates on a truncated multi-byte sequence at the end
    while start <= len do
        local lead = string.byte(str, start)
        local width = #lead_bytes
        while lead < lead_bytes[width] do
            width = width - 1
        end
        local char = string.sub(str, start, start + width - 1)
        start = start + width
        data[count] = vocab_mapping[char] or unknown_id
        count = count + 1
    end
    return data
end
-- *** STATIC method ***
--- Preprocess a raw text file into a vocabulary file and a data tensor file.
-- Characters seen at least `min_freq` times get ids 1..N in sorted order; one
-- extra id, N + 1, is stored under ' UNKNOW ' for every other character.
-- @param in_textfile path of the raw text file (read in full into memory)
-- @param out_vocabfile path the character -> id table is saved to
-- @param out_tensorfile path the id tensor is saved to
-- @param min_freq minimum occurrence count for a character to get its own id (default 1)
function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile, min_freq)
    min_freq = min_freq or 1
    print(' loading text file... ')
    local f = torch.DiskFile(in_textfile)
    local rawdata = f:readString(' *a ') -- NOTE: this reads the whole file at once
    f:close()

    -- create vocabulary if it doesn't exist yet
    print(' creating vocabulary mapping... ')
    -- record all characters to a set
    local unordered = get_vocab(rawdata, min_freq)
    -- sort into a table (i.e. keys become 1..N)
    local ordered = {}
    for char in pairs(unordered) do ordered[#ordered + 1] = char end
    table.sort(ordered)
    -- invert `ordered` to create the char->int mapping
    local vocab_mapping = {}
    for i, char in ipairs(ordered) do
        vocab_mapping[char] = i
    end
    -- one id past the real characters for anything filtered out by min_freq
    vocab_mapping[' UNKNOW '] = #ordered + 1

    -- construct a tensor with all the data
    print(' putting data into tensor, it takes a lot of time... ')
    local data = get_data(rawdata, vocab_mapping)

    -- save output preprocessed files
    print(' saving ' .. out_vocabfile)
    torch.save(out_vocabfile, vocab_mapping, ' ascii ')
    print(' saving ' .. out_tensorfile)
    torch.save(out_tensorfile, data, ' ascii ')
end
-- Module value: the class table, giving callers create() and the methods defined above.
return CharSplitLMMinibatchLoader