diff --git a/Readme.md b/Readme.md
index 57254ff..ceee560 100644
--- a/Readme.md
+++ b/Readme.md
@@ -2,14 +2,41 @@
 # char-rnn-chinese
 Based on https://github.com/karpathy/char-rnn. Makes the code work well with Chinese.
+## Chinese processing
 Makes the code able to process both English and Chinese characters.
 This is my first time using Lua, so the string handling may look clumsy, but it works well.
+## opt.min_freq
 I also added an option called 'min_freq': the Chinese vocabulary is very large, which greatly increases the number of parameters, so dropping some rare characters helps keep the model small.
+## web interface
+A web demo is included so that others can easily test a trained model. It is built on redis pub/sub;
+I use redis because I couldn't find a good RPC library or web server that integrates well with Torch.
+Note that the demo is asynchronous, driven by AJAX. To set up the demo on Ubuntu:
+Install redis and start it:
+```bash
+$ wget http://download.redis.io/releases/redis-3.0.3.tar.gz
+$ tar xzf redis-3.0.3.tar.gz
+$ cd redis-3.0.3
+$ make
+$ sudo make install
+$ redis-server &
+```
+Then install flask and the redis client for python:
+```bash
+$ sudo pip install flask
+$ sudo pip install redis
+```
+Put your model file in online_model, rename it to 'model.t7', then start the backend and frontend scripts:
+```bash
+$ nohup th web_backend.lua &
+$ nohup python web_server.py &
+```
+
 -----------------------------------------------
-Karpathy's raw Readme, please follow this to setup your experiment.
+## Karpathy's raw Readme
+Please follow this to set up your experiment.
 
 This code implements **multi-layer Recurrent Neural Network** (RNN, LSTM, and GRU) for training/sampling from character-level language models. The model learns to predict the probability of the next character in a sequence. In other words, the input is a single text file and the model learns to generate text like it.
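Once both scripts are up, the demo can be exercised without a browser. Below is a minimal sketch of the request cycle in Python 2: POST the prime text to `/api`, get back a session id, then poll `/res` until the placeholder `'0'` is replaced by the backend's answer. The host/port and field names are the ones in web_server.py at the bottom of this diff; the prime text, temperature, and seed values are just illustrative, and the use of `urllib2` (rather than the page's jQuery AJAX) is my choice for a dependency-free sketch.

```python
# Minimal command-line client for the demo API (see web_server.py below).
import json
import time
import urllib2

def post(url, payload):
    # web_server.py reads request.json, so the Content-Type header matters
    req = urllib2.Request(url, json.dumps(payload),
                          {'Content-Type': 'application/json'})
    return json.loads(urllib2.urlopen(req).read())

base = 'http://localhost:8080'
reply = post(base + '/api', {'primetext': u'hello',
                             'temperature': 0.8, 'seed': 123})
sid = reply['sid']           # 0 means web_backend.lua is not listening
result = '0'
while sid != 0 and result == '0':
    time.sleep(1)            # the backend needs a few seconds to sample
    result = post(base + '/res', {'sid': sid})['responds']
print result
```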
diff --git a/templates/main.html b/templates/main.html
new file mode 100644
index 0000000..7860e23
--- /dev/null
+++ b/templates/main.html
@@ -0,0 +1,91 @@
+<!-- main.html (91 lines): a simple page titled "char-rnn API" with an input
+     form and a result area, wired to the /api and /res endpoints via AJAX;
+     the markup was lost in extraction, so only this summary is kept here. -->
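Although the page's markup did not survive, the protocol it speaks can be read off web_server.py below: the form is submitted asynchronously and the result is polled until it arrives. The payloads, sketched as Python literals; the field names are the ones web_server.py reads and writes, while the values shown are illustrative:

```python
# /api request and reply (reconstructed from web_server.py, not from main.html)
api_request = {'primetext': u'...', 'temperature': 0.8, 'seed': 123}
api_reply   = {'sid': '0f1e2d...'}  # md5 of the current time; 0 if no backend

# /res request and reply, polled until 'responds' is no longer '0'
res_request = {'sid': '0f1e2d...'}
res_reply   = {'responds': '...generated text, or "0" while still pending...'}
```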
diff --git a/web_backend.lua b/web_backend.lua
new file mode 100644
index 0000000..7eaf133
--- /dev/null
+++ b/web_backend.lua
@@ -0,0 +1,145 @@
+require 'torch'
+require 'nngraph'
+require 'optim'
+require 'lfs'
+require 'nn'
+
+require 'util.OneHot'
+require 'util.misc'
+JSON = (loadfile "util/JSON.lua")()
+
+-- two redis connections: one blocks inside pub/sub, the other writes results
+local redis = require 'redis'
+local client = redis.connect('127.0.0.1', 6379)
+local client2 = redis.connect('127.0.0.1', 6379)
+local channels = {'cv_channel'}
+local model_file = './online_model/model.t7'
+local gpuid = 0
+local seed = 123
+
+-- check that cunn/cutorch are installed if user wants to use the GPU
+if gpuid >= 0 then
+    local ok, cunn = pcall(require, 'cunn')
+    local ok2, cutorch = pcall(require, 'cutorch')
+    if not ok then print('package cunn not found!') end
+    if not ok2 then print('package cutorch not found!') end
+    if ok and ok2 then
+        print('using CUDA on GPU ' .. gpuid .. '...')
+        cutorch.setDevice(gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
+        cutorch.manualSeed(seed)
+    else
+        print('Falling back on CPU mode')
+        gpuid = -1 -- overwrite user setting
+    end
+end
+
+if not lfs.attributes(model_file, 'mode') then
+    print('Error: File ' .. model_file .. ' does not exist.')
+    os.exit(1)
+end
+checkpoint = torch.load(model_file)
+protos = checkpoint.protos
+protos.rnn:evaluate() -- put in eval mode so that dropout works properly
+
+-- initialize the vocabulary (and its inverted version)
+local vocab = checkpoint.vocab
+local ivocab = {}
+for c,i in pairs(vocab) do ivocab[i] = c end
+
+-- split a string into characters, treating a multi-byte UTF-8 sequence
+-- (e.g. a Chinese character) as a single character
+function get_char(str)
+    local len = #str
+    local left = 0
+    local arr = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}
+    local unordered = {}
+    local start = 1
+    local wordLen = 0
+    while len ~= left do
+        local tmp = string.byte(str, start)
+        local i = #arr
+        while arr[i] do
+            if tmp >= arr[i] then
+                break
+            end
+            i = i - 1
+        end
+        wordLen = i + wordLen
+        local tmpString = string.sub(str, start, wordLen)
+        start = start + i
+        left = left + i
+        unordered[#unordered+1] = tmpString
+    end
+    return unordered
+end
+
+-- start listening for requests published by web_server.py
+for msg in client:pubsub({subscribe = channels}) do
+    if msg.kind == 'subscribe' then
+        print('Subscribed to channel '..msg.channel)
+    elseif msg.kind == 'message' then
+        local req = JSON:decode(msg.payload)
+        local primetext = '|' .. req['text'] .. '| '
+        local session_id = req['sid']
+        local seed = req['seed']
+        local temperature = req['temp']
+
+        -- initialize the rnn state to all zeros
+        local current_state = {}
+        for L = 1,checkpoint.opt.num_layers do
+            -- c and h for all layers
+            local h_init = torch.zeros(1, checkpoint.opt.rnn_size)
+            if gpuid >= 0 then h_init = h_init:cuda() end
+            table.insert(current_state, h_init:clone())
+            table.insert(current_state, h_init:clone())
+        end
+        state_size = #current_state
+
+        -- feed the prime text through the network to warm up the state
+        torch.manualSeed(seed)
+        for i,c in ipairs(get_char(primetext)) do
+            prev_char = vocab[c]
+            if prev_char then
+                prev_char = torch.Tensor{vocab[c]}
+                io.write(ivocab[prev_char[1]])
+                if gpuid >= 0 then prev_char = prev_char:cuda() end
+                local lst = protos.rnn:forward{prev_char, unpack(current_state)}
+                -- lst is a list of [state1,state2,..stateN,output]; we want
+                -- everything but the last piece
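+                -- (current_state ends up holding 2*num_layers tensors, since
+                -- the init loop above assumes an LSTM: one cell state and one
+                -- hidden state per layer, each of size 1 x rnn_size)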
+                current_state = {}
+                for i=1,state_size do table.insert(current_state, lst[i]) end
+                prediction = lst[#lst] -- last element holds the log probabilities
+            end
+        end
+
+        -- start sampling/argmaxing
+        result = ''
+        not_end = true
+        for i=1,1000 do
+            -- turn the log probabilities from the previous timestep into a
+            -- distribution, then redraw until the sampled character is not
+            -- UNKNOW (the placeholder that min_freq maps rare characters to)
+            prediction:div(temperature) -- scale by temperature
+            local probs = torch.exp(prediction):squeeze()
+            probs:div(torch.sum(probs)) -- renormalize so probs sum to one
+            real_char = 'UNKNOW'
+            while real_char == 'UNKNOW' do
+                prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
+                real_char = ivocab[prev_char[1]]
+            end
+
+            -- forward the rnn for the next character
+            local lst = protos.rnn:forward{prev_char, unpack(current_state)}
+            current_state = {}
+            for i=1,state_size do table.insert(current_state, lst[i]) end
+            prediction = lst[#lst] -- last element holds the log probabilities
+            result = result .. ivocab[prev_char[1]]
+            -- five consecutive newlines mark a finished sample
+            if string.find(result, '\n\n\n\n\n') then
+                not_end = false
+                break
+            end
+        end
+        if not_end then result = result .. '……' end -- cut off at 1000 chars
+        -- store the result under the session id with a 100-second expiry
+        client2:setex(session_id, 100, result)
+    end
+end
diff --git a/web_server.py b/web_server.py
new file mode 100644
index 0000000..76e2a7a
--- /dev/null
+++ b/web_server.py
@@ -0,0 +1,51 @@
+#!/usr/bin/python
+#encoding=utf-8
+import sys
+reload(sys)
+sys.setdefaultencoding('utf8')
+
+from flask import Flask
+from flask import jsonify, render_template, request, abort
+import redis
+import time
+import json
+import hashlib
+
+app = Flask(__name__)
+channel_name = 'cv_channel'
+
+@app.route('/')
+def index():
+    return render_template('main.html')
+
+@app.route('/api', methods=['POST'])
+def api():
+    if not request.json or 'primetext' not in request.json:
+        abort(400)
+    req = {}
+    req['text'] = request.json['primetext']
+    req['temp'] = request.json['temperature']
+    req['seed'] = request.json['seed']
+    # use the md5 of the current time as a session id
+    m = hashlib.md5()
+    m.update(str(time.time()))
+    req['sid'] = m.hexdigest()
+
+    r = redis.StrictRedis(host='localhost', port=6379, db=0)
+    # publish() returns the number of subscribers; 0 means no backend is listening
+    res = r.publish(channel_name, json.dumps(req))
+    print res
+    if res == 0:
+        req['sid'] = 0
+
+    return jsonify({'sid': req['sid']}), 200
+
+@app.route('/res', methods=['POST'])
+def res():
+    r = redis.StrictRedis(host='localhost', port=6379, db=0)
+    sid = request.json['sid']
+    responds = r.get(sid)
+    # the backend has not answered yet (or the result already expired)
+    if responds is None:
+        responds = '0'
+    return jsonify({'responds': responds}), 200
+
+if __name__ == "__main__":
+    app.run(host='0.0.0.0', port=8080, debug=True)
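For debugging, web_backend.lua can also be driven directly over redis, bypassing Flask: publish a request on cv_channel and read the answer back from the key named by the sid, which the backend stores with a 100-second expiry. A minimal sketch using the redis-py client installed earlier; the field names and channel come from the code above, while the sid string and the 5-second wait are arbitrary choices of mine:

```python
# Talk to web_backend.lua directly over redis (no Flask involved).
import json
import time
import redis

r = redis.StrictRedis(host='localhost', port=6379, db=0)
req = {'text': u'hello', 'temp': 0.8, 'seed': 123, 'sid': 'debug-sid'}
# publish() returns the number of subscribers; 0 means the backend is down
print r.publish('cv_channel', json.dumps(req))
time.sleep(5)             # give the backend time to sample
print r.get('debug-sid')  # None until ready; expires 100s after being set
```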