add web demo
This commit is contained in:
parent
77b52e630b
commit
329021ad69
29
Readme.md
29
Readme.md
@ -2,14 +2,41 @@
|
|||||||
# char-rnn-chinese
|
# char-rnn-chinese
|
||||||
Based on https://github.com/karpathy/char-rnn. make the code work well with Chinese.
|
Based on https://github.com/karpathy/char-rnn. make the code work well with Chinese.
|
||||||
|
|
||||||
|
## Chinese process
|
||||||
Make the code can process both English and Chinese characters.
|
Make the code able to process both English and Chinese characters.
|
||||||
This is my first touch of Lua, so the string process seems silly, but it works well.
|
This is my first touch of Lua, so the string process seems silly, but it works well.
|
||||||
|
|
||||||
|
## opt.min_freq
|
||||||
I also add an option called 'min_freq' because the vocab size in Chinese is very big, which makes the parameter num increase a lot.
|
I also add an option called 'min_freq' because the vocab size in Chinese is very big, which makes the parameter num increase a lot.
|
||||||
So delete some rare character may help.
|
So deleting some rare characters may help.
|
||||||
|
|
||||||
|
## web interface
|
||||||
|
A web demo is added for others to test model easily, based on sub/pub of redis.
|
||||||
|
I use redis because I couldn't find a good RPC framework or web server that integrates well with Torch.
|
||||||
|
You should notice that the demo is async by ajax. To setup the demo on ubuntu:
|
||||||
|
Install redis and start it
|
||||||
|
```bash
|
||||||
|
$ wget http://download.redis.io/releases/redis-3.0.3.tar.gz
|
||||||
|
$ tar xzf redis-3.0.3.tar.gz
|
||||||
|
$ cd redis-3.0.3
|
||||||
|
$ make
|
||||||
|
$ sudo make install
|
||||||
|
$ redis-server &
|
||||||
|
```
|
||||||
|
Then install flask and the redis plugin for python:
|
||||||
|
```bash
|
||||||
|
$ sudo pip install flask
|
||||||
|
$ sudo pip install redis
|
||||||
|
```
|
||||||
|
Put your model file in online_model, rename it to 'model.t7', then start the backend and frontend scripts:
|
||||||
|
```bash
|
||||||
|
$ nohup th web_backend.lua &
|
||||||
|
$ nohup python web_server.py &
|
||||||
|
```
|
||||||
|
|
||||||
-----------------------------------------------
|
-----------------------------------------------
|
||||||
Karpathy's raw Readme, please follow this to setup your experiment.
|
## Karpathy's raw Readme
|
||||||
|
please follow this to setup your experiment.
|
||||||
|
|
||||||
This code implements **multi-layer Recurrent Neural Network** (RNN, LSTM, and GRU) for training/sampling from character-level language models. The model learns to predict the probability of the next character in a sequence. In other words, the input is a single text file and the model learns to generate text like it.
|
This code implements **multi-layer Recurrent Neural Network** (RNN, LSTM, and GRU) for training/sampling from character-level language models. The model learns to predict the probability of the next character in a sequence. In other words, the input is a single text file and the model learns to generate text like it.
|
||||||
|
|
||||||
|
91
templates/main.html
Normal file
91
templates/main.html
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
<!DOCTYPE html>
<!-- Front-end for the char-rnn web demo: submits a sampling request to
     /api, then polls /res every second until the backend has published
     the generated text into redis. -->
<html>
<head>
<title>char-rnn API</title>
<meta charset="utf-8">
<meta content="initial-scale=1, minimum-scale=1, width=device-width" name="viewport">
<script src="http://cdn.bootcss.com/jquery/2.1.4/jquery.min.js"></script>
<link href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.4/css/bootstrap.min.css" rel="stylesheet">
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.4/js/bootstrap.min.js"></script>

<style>
body{ padding:20px; padding-top:0px;}
#form_net_sample{max-width:650px;margin-right:auto;margin-left:auto;}
.description{font-weight:200;font-size:13px;}
label{margin-top:5px;}
</style>

<script>
// POST the sampling request (primetext/temperature/seed) to /api.
// The server answers with {sid: ...}; sid === 0 means no backend worker.
function getChar(inputdata,callback) {
    $.ajax({
        type: "POST",
        contentType: "application/json; charset=utf-8",
        url: "/api",
        data: JSON.stringify(inputdata),
        success: function (data) {
            callback(data);
        },
        dataType: "json"
    });
}
// POST the session id to /res; the server returns {responds: ...},
// where '0' means the result is not ready yet.
function getRes(sid,callback2) {
    $.ajax({
        type: "POST",
        contentType: "application/json; charset=utf-8",
        url: "/res",
        data: JSON.stringify({"sid":sid}),
        success: function (res) {
            callback2(res);
        },
        dataType: "json"
    });
}
$(function() {
    var interval;
    // After /api responds, poll /res once per second until callback2
    // sees a real result and clears the interval.
    // NOTE(review): when sid == 0 the interval is never cleared, so the
    // error text is rewritten every second — harmless but wasteful.
    function callback(data){
        interval = setInterval(function(){
            if(data.sid == 0)
                $('#form_output').val('backend service not found.');
            else
                getRes(data.sid, callback2);
        }, 1000);
    }
    // Stop polling and show the generated text once it is available.
    function callback2(res){
        if(res.responds != '0'){
            clearInterval(interval);
            $('#form_output').val(res.responds);
        }
    }
    $( "#form_net_sample" ).submit(function( event ) {
        event.preventDefault();
        $('#form_output').val('load...');
        var primetext = $('#form_input').val();
        if(primetext.length <= 0){primetext = '';}
        var temperature = $('#form_temperature').val();
        // NOTE(review): the label advertises (0-1) but this accepts up to 10.
        if(temperature <= 0 || temperature > 10){temperature = '1';}
        var seed = $('#form_seed').val();
        if(seed.length <= 0){seed = '123';}
        getChar({"primetext":primetext, "temperature":temperature, "seed":seed},callback);
    });
});
</script>
</head>
<body>

<form method="post" id="form_net_sample" class="form-group">
<label for="form_input">primetext<span class="description"></span></label>
<input name="form_input" type="text" class="form-control" id="form_input" placeholder="your text" value="">
<label for="form_temperature">temperature<span class="description">(0-1)</span></label>
<input name="form_temperature" type="text" class="form-control" id="form_temperature" placeholder="0.7" value="0.7">
<label for="form_seed">seed<span class="description"> (any number)</span></label>
<input name="form_seed" type="text" class="form-control" id="form_seed" placeholder="1" value="1">

<br/>
<button type="submit" class="btn btn-default">submit</button>
<br/><br/>

<label for="form_output">result</label>
<textarea disabled name="form_output" id="form_output" class="form-control" rows="15"></textarea>
</form>
</body>
</html>
|
145
web_backend.lua
Normal file
145
web_backend.lua
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
-- Backend worker for the char-rnn web demo: loads a trained checkpoint
-- and serves sampling requests published on a redis channel.
require 'torch'
require 'nngraph'
require 'optim'
require 'lfs'
require 'nn'

require 'util.OneHot'
require 'util.misc'
JSON = (loadfile "util/JSON.lua")()

-- Two connections: 'client' blocks inside the pubsub loop,
-- 'client2' is free to write results back.
local redis = require 'redis'
local client = redis.connect('127.0.0.1', 6379)
local client2 = redis.connect('127.0.0.1', 6379)
local channels = {'cv_channel'}
-- Fixed path typo ('onlie_model' -> 'online_model'): the Readme tells
-- users to put the model file in online_model/, so the old path could
-- never find it.
local model_file = './online_model/model.t7'
local gpuid = 0   -- >= 0 selects a CUDA device; set to -1 for CPU
local seed = 123  -- RNG seed used for the CUDA init below

-- check that cunn/cutorch are installed if user wants to use the GPU
if gpuid >= 0 then
    local ok, cunn = pcall(require, 'cunn')
    local ok2, cutorch = pcall(require, 'cutorch')
    if not ok then print('package cunn not found!') end
    if not ok2 then print('package cutorch not found!') end
    if ok and ok2 then
        print('using CUDA on GPU ' .. gpuid .. '...')
        cutorch.setDevice(gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
        cutorch.manualSeed(seed)
    else
        print('Falling back on CPU mode')
        gpuid = -1 -- overwrite user setting
    end
end

-- Fail hard when the model file is missing: the old code only printed a
-- message and then crashed inside torch.load with a less helpful error.
if not lfs.attributes(model_file, 'mode') then
    error('Error: File ' .. model_file .. ' does not exist.')
end
checkpoint = torch.load(model_file)
protos = checkpoint.protos
protos.rnn:evaluate() -- put in eval mode so that dropout works properly

-- initialize the vocabulary (and its inverted version)
local vocab = checkpoint.vocab
local ivocab = {}
for c,i in pairs(vocab) do ivocab[i] = c end
|
||||||
|
|
||||||
|
-- parse characters from a string
|
||||||
|
-- Split a (possibly multi-byte UTF-8) string into an ordered array of
-- whole characters, so Chinese and ASCII text can both be fed to the
-- vocab lookup one character at a time.
function get_char(str)
    local len = #str            -- total byte length of the input
    local left = 0              -- bytes consumed so far
    -- UTF-8 lead-byte thresholds: the index i with tmp >= arr[i] equals
    -- the byte length of the character starting at that lead byte.
    local arr = {0, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc}
    local unordered = {}        -- output list of characters (in input order)
    local start = 1             -- byte index where the current character starts
    local wordLen = 0           -- running absolute end index of the character
    while len ~= left do
        local tmp = string.byte(str, start)
        -- scan thresholds from largest down; i ends as the byte count
        local i = #arr
        while arr[i] do
            if tmp >= arr[i] then
                break
            end
            i = i - 1
        end
        -- wordLen accumulates so it always points at the last byte of
        -- the current character in absolute string coordinates
        wordLen = i + wordLen
        local tmpString = string.sub(str, start, wordLen)
        start = start + i
        left = left + i
        unordered[#unordered+1] = tmpString
    end
    return unordered
end
|
||||||
|
|
||||||
|
-- start listen: block on the redis channel and handle one sampling
-- request per published message.  Results are written back under the
-- request's session id with a 100s expiry for the web server to poll.
for msg in client:pubsub({subscribe = channels}) do
    if msg.kind == 'subscribe' then
        print('Subscribed to channel '..msg.channel)
    elseif msg.kind == 'message' then
        -- print('Received the following message from '..msg.channel.."\n "..msg.payload.."\n")
        local req = JSON:decode(msg.payload)
        -- primetext is wrapped in '|...| ' delimiters before priming.
        -- NOTE(review): seed and temperature arrive as JSON strings from
        -- the web server; torch appears to coerce them — confirm.
        local primetext = '|' .. req['text'] .. '| '
        local session_id = req['sid']
        local seed = req['seed']
        local temperature = req['temp']

        -- initialize the rnn state to all zeros (c and h for all layers)
        local current_state
        local num_layers = checkpoint.opt.num_layers
        current_state = {}
        for L = 1,checkpoint.opt.num_layers do
            local h_init = torch.zeros(1, checkpoint.opt.rnn_size)
            if gpuid >= 0 then h_init = h_init:cuda() end
            table.insert(current_state, h_init:clone())
            table.insert(current_state, h_init:clone())
        end
        state_size = #current_state

        -- use input to init state: run the primetext through the rnn,
        -- silently skipping characters that are not in the vocab
        torch.manualSeed(seed)
        for i,c in ipairs(get_char(primetext)) do
            prev_char = vocab[c]
            if prev_char then
                prev_char = torch.Tensor{vocab[c]}
                io.write(ivocab[prev_char[1]])
                if gpuid >= 0 then prev_char = prev_char:cuda() end
                local lst = protos.rnn:forward{prev_char, unpack(current_state)}
                -- lst is a list of [state1,state2,..stateN,output]. We want everything but last piece
                current_state = {}
                for i=1,state_size do table.insert(current_state, lst[i]) end
                prediction = lst[#lst] -- last element holds the log probabilities
            end
        end
        -- start sampling/argmaxing, up to 1000 characters
        result = ''
        not_end = true
        for i=1,1000 do
            -- log probabilities from the previous timestep
            -- make sure the output char is not UNKNOW
            real_char = 'UNKNOW'
            while(real_char == 'UNKNOW') do
                -- NOTE(review): the seed is re-fixed on every retry, so
                -- multinomial only draws a different sample because
                -- prediction:div(temperature) is re-applied each pass —
                -- the retry loop's termination depends on that re-division.
                -- Do not "hoist" either line without rethinking both.
                torch.manualSeed(seed+1)
                prediction:div(temperature) -- scale by temperature
                local probs = torch.exp(prediction):squeeze()
                probs:div(torch.sum(probs)) -- renormalize so probs sum to one
                prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
                real_char = ivocab[prev_char[1]]
            end

            -- forward the rnn for next character
            local lst = protos.rnn:forward{prev_char, unpack(current_state)}
            current_state = {}
            for i=1,state_size do table.insert(current_state, lst[i]) end
            prediction = lst[#lst] -- last element holds the log probabilities
            result = result .. ivocab[prev_char[1]]
            -- five consecutive newlines are treated as a natural ending
            if string.find(result, '\n\n\n\n\n') then
                not_end = false
                break
            end
        end
        -- mark truncated output with an ellipsis
        if not_end then result = result .. '……' end
        -- client2:set(session_id, result)
        -- expire after 100s so abandoned sessions do not accumulate
        client2:setex(session_id, 100, result)
    end
end
|
51
web_server.py
Normal file
51
web_server.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/python
#encoding=utf-8
# Flask front-end for the char-rnn web demo: forwards sampling requests
# to the Lua backend over redis pub/sub and serves the results.
# NOTE: Python 2 only — reload(sys)/setdefaultencoding do not exist in
# Python 3; they force utf-8 as the default string encoding here.
import sys
reload(sys)
sys.setdefaultencoding('utf8')

from flask import Flask
from flask import jsonify,render_template,request,abort
import redis
import time
import json
import hashlib

app = Flask(__name__)
# Must match the channel the Lua backend subscribes to (web_backend.lua).
channel_name = 'cv_channel'
|
||||||
|
|
||||||
|
@app.route('/')
def index():
    """Serve the demo's single-page front-end."""
    page = render_template('main.html')
    return page
|
||||||
|
|
||||||
|
@app.route('/api', methods=['POST'])
def api():
    """Accept a sampling request and publish it to the redis channel.

    Expects a JSON body with 'primetext' (required) plus 'temperature'
    and 'seed' (optional).  Returns the session id the client should
    poll /res with, or sid=0 when no backend worker is subscribed.
    """
    if not request.json or 'primetext' not in request.json:
        abort(400)
    req = {}
    req['text'] = request.json['primetext']
    # .get() with the front-end's own defaults: a missing optional field
    # used to raise KeyError (an HTTP 500) instead of degrading gracefully.
    req['temp'] = request.json.get('temperature', '1')
    req['seed'] = request.json.get('seed', '123')
    # Session id: md5 of the current timestamp uniquely tags this request.
    m = hashlib.md5()
    m.update(str(time.time()))
    req['sid'] = m.hexdigest()

    r = redis.StrictRedis(host='localhost', port=6379, db=0)
    res = r.publish(channel_name, json.dumps(req))
    # publish() returns the subscriber count; 0 means the Lua backend is
    # not running.  Parenthesized print works on both Python 2 and 3.
    print(res)
    if res == 0:
        req['sid'] = 0

    return jsonify({'sid': req['sid']}), 200
|
||||||
|
|
||||||
|
@app.route('/res', methods=['POST'])
def res():
    """Poll for the generated text stored under the given session id.

    Returns '0' while the backend has not yet written a result.
    """
    conn = redis.StrictRedis(host='localhost', port=6379, db=0)
    stored = conn.get(request.json['sid'])
    if stored is None:
        stored = '0'
    return jsonify({'responds': stored}), 200
|
||||||
|
|
||||||
|
# Run the Flask dev server on all interfaces; debug=True is acceptable
# only because this is a demo.
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=8080, debug=True)
|
Loading…
Reference in New Issue
Block a user