diff --git a/Readme.md b/Readme.md
index 254eb3d..72d8b52 100644
--- a/Readme.md
+++ b/Readme.md
@@ -35,6 +35,13 @@
 $ nohup th web_backend.lua &
 $ nohup python web_server.py &
 ```
+## Model conversion between CPU and GPU
+I added a script that converts a model file trained on a GPU into a CPU model.
+You can try it as follows:
+```bash
+$ th convert.lua gpu_model cpu_model
+```
+
 -----------------------------------------------
 ## Karpathy's raw Readme
 please follow this to setup your experiment.
diff --git a/convert.lua b/convert.lua
new file mode 100644
index 0000000..b7b0682
--- /dev/null
+++ b/convert.lua
@@ -0,0 +1,42 @@
+-- Convert a checkpoint trained on the GPU (CudaTensor weights) into a
+-- CPU model (FloatTensor weights). Usage: th convert.lua <gpu_model> <cpu_model>
+require 'torch'
+require 'nngraph'
+require 'optim'
+require 'lfs'
+require 'nn'
+
+require 'util.OneHot'
+require 'util.misc'
+
+cmd = torch.CmdLine()
+cmd:text()
+cmd:text('convert a gpu model to cpu one')
+cmd:text()
+cmd:text('Options')
+
+cmd:argument('-load_model','model to convert')
+cmd:argument('-save_file','the file path to save the converted model')
+cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
+cmd:text()
+
+-- parse input params
+opt = cmd:parse(arg)
+-- cunn/cutorch must be loadable to deserialize a checkpoint containing CudaTensors
+local ok, cunn = pcall(require, 'cunn')
+local ok2, cutorch = pcall(require, 'cutorch')
+if not ok then print('package cunn not found!') end
+if not ok2 then print('package cutorch not found!') end
+if ok and ok2 then
+    print('using CUDA on GPU ' .. opt.gpuid .. '...')
+    cutorch.setDevice(opt.gpuid + 1) -- note +1 to make it 0 indexed! sigh lua
+else
+    print('No gpu found. Convert fail.')
+    os.exit()
+end
+
+-- move every module of the checkpoint to CPU float tensors and re-save it
+checkpoint = torch.load(opt.load_model)
+checkpoint.protos.rnn:float()
+checkpoint.protos.criterion:float()
+torch.save(opt.save_file, checkpoint)
diff --git a/sample.lua b/sample.lua
index 8504017..fb58685 100644
--- a/sample.lua
+++ b/sample.lua
@@ -80,7 +80,7 @@ local num_layers = checkpoint.opt.num_layers
 current_state = {}
 for L = 1,checkpoint.opt.num_layers do
     -- c and h for all layers
-    local h_init = torch.zeros(1, checkpoint.opt.rnn_size)
+    local h_init = torch.zeros(1, checkpoint.opt.rnn_size):float()
     if opt.gpuid >= 0 then h_init = h_init:cuda() end
     table.insert(current_state, h_init:clone())
     table.insert(current_state, h_init:clone())
diff --git a/web_backend.lua b/web_backend.lua
index 7eaf133..8483e30 100644
--- a/web_backend.lua
+++ b/web_backend.lua
@@ -89,7 +89,7 @@ for msg in client:pubsub({subscribe = channels}) do
     current_state = {}
     for L = 1,checkpoint.opt.num_layers do
         -- c and h for all layers
-        local h_init = torch.zeros(1, checkpoint.opt.rnn_size)
+        local h_init = torch.zeros(1, checkpoint.opt.rnn_size):float()
         if gpuid >= 0 then h_init = h_init:cuda() end
         table.insert(current_state, h_init:clone())
         table.insert(current_state, h_init:clone())