-- Modified from https://github.com/oxford-cs-ml-2015/practical6
-- the modification included support for train/val/test splits

require 'torch'
local lfs = require 'lfs'          -- luafilesystem, used to check file modification times
local path = require 'pl.path'     -- penlight path utilities (path.join, path.exists)

local CharSplitLMMinibatchLoader = {}
CharSplitLMMinibatchLoader.__index = CharSplitLMMinibatchLoader

function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, split_fractions, min_freq)
    -- split_fractions is e.g. {0.9, 0.05, 0.05}

    local self = {}
    setmetatable(self, CharSplitLMMinibatchLoader)

    local input_file = path.join(data_dir, 'input.txt')
    local vocab_file = path.join(data_dir, 'vocab.t7')
    local tensor_file = path.join(data_dir, 'data.t7')

    -- fetch file attributes to determine if we need to rerun preprocessing
    local run_prepro = false
    if not (path.exists(vocab_file) and path.exists(tensor_file)) then
        -- prepro files do not (both) exist, generate them
        print('vocab.t7 and/or data.t7 do not exist. Running preprocessing...')
        run_prepro = true
    else
        -- check if the input file was modified since last time we
        -- ran the prepro. if so, we have to rerun the preprocessing
        local input_attr = lfs.attributes(input_file)
        local vocab_attr = lfs.attributes(vocab_file)
        local tensor_attr = lfs.attributes(tensor_file)
        if input_attr.modification > vocab_attr.modification or input_attr.modification > tensor_attr.modification then
            print('vocab.t7 or data.t7 detected as stale. Re-running preprocessing...')
            run_prepro = true
        end
    end
    if run_prepro then
        -- construct a tensor with all the data, and a vocab file
        print('one-time setup: preprocessing input text file ' .. input_file .. '...')
        CharSplitLMMinibatchLoader.text_to_tensor(input_file, vocab_file, tensor_file, min_freq)
    end

    print('loading data files...')
    local data = torch.load(tensor_file)
    self.vocab_mapping = torch.load(vocab_file)

    -- cut off the end so that it divides evenly
    local len = data:size(1)
    if len % (batch_size * seq_length) ~= 0 then
        print('cutting off end of data so that the batches/sequences divide evenly')
        data = data:sub(1, batch_size * seq_length * math.floor(len / (batch_size * seq_length)))
    end

    -- count vocab
    self.vocab_size = 0
    for _ in pairs(self.vocab_mapping) do
        self.vocab_size = self.vocab_size + 1
    end

    -- self.x_batches and self.y_batches are tables of tensors
    print('reshaping tensor...')
    self.batch_size = batch_size
    self.seq_length = seq_length

    local ydata = data:clone()
    ydata:sub(1,-2):copy(data:sub(2,-1))   -- targets are inputs shifted by one
    ydata[-1] = data[1]                    -- the last target wraps around to the first token
    self.x_batches = data:view(batch_size, -1):split(seq_length, 2)  -- nbatches tensors, each batch_size x seq_length
    self.nbatches = #self.x_batches
    self.y_batches = ydata:view(batch_size, -1):split(seq_length, 2)
    assert(#self.x_batches == #self.y_batches)

    -- let's try to be helpful here
    if self.nbatches < 50 then
        print('WARNING: fewer than 50 batches in the data in total. This looks like a very small dataset. You probably want to use a smaller batch_size and/or seq_length.')
    end
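    -- Shapes at this point, with illustrative numbers (not taken from the data):
    -- given 1,000 tokens, batch_size = 2 and seq_length = 5, the trimmed data is
    -- viewed as a 2x500 tensor and split along dim 2 into chunks of width 5, so
    -- x_batches and y_batches each hold nbatches = 100 tensors of size 2x5, and
    -- y_batches[i] is x_batches[i] shifted by one token.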
    -- perform safety checks on split_fractions
    assert(split_fractions[1] >= 0 and split_fractions[1] <= 1, 'bad split fraction ' .. split_fractions[1] .. ' for train, not between 0 and 1')
    assert(split_fractions[2] >= 0 and split_fractions[2] <= 1, 'bad split fraction ' .. split_fractions[2] .. ' for val, not between 0 and 1')
    assert(split_fractions[3] >= 0 and split_fractions[3] <= 1, 'bad split fraction ' .. split_fractions[3] .. ' for test, not between 0 and 1')
    if split_fractions[3] == 0 then
        -- catch a common special case where the user might not want a test set
        self.ntrain = math.floor(self.nbatches * split_fractions[1])
        self.nval = self.nbatches - self.ntrain
        self.ntest = 0
    else
        -- divide data to train/val and allocate the rest to test
        self.ntrain = math.floor(self.nbatches * split_fractions[1])
        self.nval = math.floor(self.nbatches * split_fractions[2])
        self.ntest = self.nbatches - self.nval - self.ntrain -- the rest goes to test (to ensure this adds up exactly)
    end

    self.split_sizes = {self.ntrain, self.nval, self.ntest}
    self.batch_ix = {0, 0, 0}

    print(string.format('data load done. Number of data batches in train: %d, val: %d, test: %d', self.ntrain, self.nval, self.ntest))
    collectgarbage()
    return self
end

function CharSplitLMMinibatchLoader:reset_batch_pointer(split_index, batch_index)
    batch_index = batch_index or 0
    self.batch_ix[split_index] = batch_index
end

function CharSplitLMMinibatchLoader:next_batch(split_index)
    if self.split_sizes[split_index] == 0 then
        -- perform a check here to make sure the user isn't screwing something up
        local split_names = {'train', 'val', 'test'}
        print('ERROR. Code requested a batch for split ' .. split_names[split_index] .. ', but this split has no data.')
        os.exit() -- crash violently
    end
    -- split_index is an integer: 1 = train, 2 = val, 3 = test
    self.batch_ix[split_index] = self.batch_ix[split_index] + 1
    if self.batch_ix[split_index] > self.split_sizes[split_index] then
        self.batch_ix[split_index] = 1 -- cycle around to the beginning
    end
    -- pull out the correct next batch
    local ix = self.batch_ix[split_index]
    if split_index == 2 then ix = ix + self.ntrain end -- offset by train set size
    if split_index == 3 then ix = ix + self.ntrain + self.nval end -- offset by train + val
    return self.x_batches[ix], self.y_batches[ix]
end

-- chinese vocab:
-- walk the raw UTF-8 string one character (not byte) at a time, count how often
-- each character occurs, and keep only those seen at least min_freq times
function get_vocab(str, min_freq)
    local len = #str
    local left = 0
    local arr = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}  -- leading-byte thresholds -> UTF-8 sequence length
    local unordered = {}
    local start = 1
    local wordLen = 0
    g_total_chars = 0  -- intentionally global: get_data() uses it to size its tensor
    while len ~= left do
        local tmp = string.byte(str, start)
        -- find the largest i with tmp >= arr[i]; i is the byte length of this character
        local i = #arr
        while arr[i] do
            if tmp >= arr[i] then break end
            i = i - 1
        end
        wordLen = i + wordLen
        local tmpString = string.sub(str, start, wordLen)
        start = start + i
        left = left + i
        if not unordered[tmpString] then
            unordered[tmpString] = 1
        else
            unordered[tmpString] = unordered[tmpString] + 1
        end
        g_total_chars = g_total_chars + 1
    end
    local final_res = {}
    for char_val, char_cnt in pairs(unordered) do
        if char_cnt >= min_freq then
            final_res[char_val] = true
        end
    end
    return final_res
end

-- change raw data to tokens
function get_data(str, vocab_mapping)
    -- cannot use torch.ByteTensor because it supports no more than 256 distinct values
    -- local data = torch.ByteTensor(g_total_chars)
    -- store it into 1D first, then rearrange
    local data = torch.ShortTensor(g_total_chars)
    local len = #str
    local left = 0
    local arr = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}
    local start = 1
    local wordLen = 0
    local count = 1
    while len ~= left do
        local tmp = string.byte(str, start)
        local i = #arr
        while arr[i] do
            if tmp >= arr[i] then break end
            i = i - 1
        end
        wordLen = i + wordLen
        local tmpString = string.sub(str, start, wordLen)
        start = start + i
        left = left + i
        if vocab_mapping[tmpString] then
            data[count] = vocab_mapping[tmpString]
        else
            data[count] = vocab_mapping['UNKNOW']  -- out-of-vocabulary character
        end
        count = count + 1
    end
    return data
end
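-- How the two helpers above step through UTF-8 (a sketch of the idea, not a full
-- decoder): the thresholds in `arr` map a character's leading byte to its byte
-- length, roughly
--   byte <  0xC0         -> 1 byte (ASCII; continuation bytes never appear at `start`)
--   0xC0 <= byte < 0xE0  -> 2 bytes
--   0xE0 <= byte < 0xF0  -> 3 bytes (most CJK characters, e.g. string.byte('中') == 0xE4)
--   0xF0 and above       -> 4 bytes, up to the legacy 5/6-byte forms.
-- Each loop iteration reads the leading byte, looks up the length i, slices the
-- next i bytes out of the string as one "character", and advances start/left by i.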
-- *** STATIC method ***
function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile, min_freq)
    local timer = torch.Timer()

    print('loading text file...')
    local f = torch.DiskFile(in_textfile)
    local rawdata = f:readString('*a') -- NOTE: this reads the whole file at once
    f:close()

    -- create vocabulary if it doesn't exist yet
    print('creating vocabulary mapping...')
    -- record all characters that occur at least min_freq times into a set
    local unordered = get_vocab(rawdata, min_freq)
    -- sort into a table (i.e. keys become 1..N)
    local ordered = {}
    for char in pairs(unordered) do ordered[#ordered + 1] = char end
    table.sort(ordered)
    -- invert `ordered` to create the char->int mapping
    local vocab_mapping = {}
    local count_vocab = 0
    for i, char in ipairs(ordered) do
        vocab_mapping[char] = i
        count_vocab = count_vocab + 1
    end
    -- reserve one extra id for out-of-vocabulary characters
    vocab_mapping['UNKNOW'] = count_vocab + 1

    -- construct a tensor with all the data
    print('putting data into tensor, this can take a while...')
    local data = get_data(rawdata, vocab_mapping)

    -- save output preprocessed files
    print('saving ' .. out_vocabfile)
    torch.save(out_vocabfile, vocab_mapping)
    print('saving ' .. out_tensorfile)
    torch.save(out_tensorfile, data)
end

return CharSplitLMMinibatchLoader
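-- Illustrative usage from a training script (a sketch only; the require path,
-- data directory and option values below are assumptions, not part of this module):
--
--   local CharSplitLMMinibatchLoader = require 'util.CharSplitLMMinibatchLoader'
--   -- batch_size = 50, seq_length = 50, 95%/5%/0% train/val/test, min_freq = 1
--   local loader = CharSplitLMMinibatchLoader.create('data/my_corpus', 50, 50, {0.95, 0.05, 0}, 1)
--   loader:reset_batch_pointer(1)        -- 1 = train, 2 = val, 3 = test
--   local x, y = loader:next_batch(1)    -- each is a batch_size x seq_length tensor of token ids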