21 lines
670 B
Lua
21 lines
670 B
Lua
|
|
-- Declare the OneHot class via torch.class, naming 'nn.Module' as its base.
-- `parent` is the second value returned by torch.class; the constructor
-- below uses it to invoke the base-class initializer.
local OneHot, parent = torch.class('OneHot', 'nn.Module')
|
|
|
|
--- Construct a OneHot module.
-- @param outputSize width of each encoded row; also the side length of
--   the identity matrix cached on the instance.
function OneHot:__init(outputSize)
  parent.__init(self)

  -- A one-hot row is a row of the identity matrix, so encodings are
  -- produced by picking rows out of it with `index`. Build the matrix a
  -- single time here and keep it, instead of rebuilding it on each pass.
  local identity = torch.eye(outputSize)
  self._eye = identity
  self.outputSize = outputSize
end
|
|
|
|
--- Encode a batch of indices as one-hot rows.
-- Row i of the result is the row of the cached identity matrix selected
-- by input[i].
-- @param input tensor of indices; it is converted with `:long()` before
--   being used as the index argument. Presumably 1-D with values in
--   1..outputSize -- TODO confirm against callers.
-- @return self.output, resized to input:size(1) x self.outputSize.
function OneHot:updateOutput(input)
  -- Recreate the identity matrix when the cached field is absent.
  if self._eye == nil then
    self._eye = torch.eye(self.outputSize)
  end
  -- Keep the cached matrix stored as a float tensor.
  self._eye = self._eye:float()

  -- Gather one identity row per index.
  local rows = self._eye:index(1, input:long())

  -- `copy` writes every element of the freshly resized buffer, so no
  -- zero-fill is needed beforehand.
  self.output:resize(input:size(1), self.outputSize):copy(rows)
  return self.output
end
|