fix convert bug, use double instead of float
This commit is contained in:
parent
94e0ee8f39
commit
e9ad92ebfd
@ -33,6 +33,6 @@ else
|
|||||||
end
|
end
|
||||||
|
|
||||||
checkpoint = torch.load(opt.load_model)
|
checkpoint = torch.load(opt.load_model)
|
||||||
checkpoint.protos.rnn:float()
|
checkpoint.protos.rnn:double()
|
||||||
checkpoint.protos.criterion:float()
|
checkpoint.protos.criterion:double()
|
||||||
torch.save(opt.save_file, checkpoint)
|
torch.save(opt.save_file, checkpoint)
|
||||||
|
@ -80,7 +80,7 @@ local num_layers = checkpoint.opt.num_layers
|
|||||||
current_state = {}
|
current_state = {}
|
||||||
for L = 1,checkpoint.opt.num_layers do
|
for L = 1,checkpoint.opt.num_layers do
|
||||||
-- c and h for all layers
|
-- c and h for all layers
|
||||||
local h_init = torch.zeros(1, checkpoint.opt.rnn_size):float()
|
local h_init = torch.zeros(1, checkpoint.opt.rnn_size):double()
|
||||||
if opt.gpuid >= 0 then h_init = h_init:cuda() end
|
if opt.gpuid >= 0 then h_init = h_init:cuda() end
|
||||||
table.insert(current_state, h_init:clone())
|
table.insert(current_state, h_init:clone())
|
||||||
table.insert(current_state, h_init:clone())
|
table.insert(current_state, h_init:clone())
|
||||||
|
Loading…
Reference in New Issue
Block a user