Source code for researchutils.chainer.links.connection.grad_clip_lstm

import functools
import operator

from researchutils.chainer.functions.activation import grad_clip_lstm

from chainer.backends import cuda
from chainer.links.connection import lstm
from chainer.functions.array import concat
from chainer.functions.array import split_axis
from chainer import variable


[docs]class GradClipLSTM(lstm.LSTM): """ Fully-connected LSTM layer with gradient clip before each gates. See: https://arxiv.org/abs/1308.0850 For detail description of LSTM layer itself, check original LSTM implementation of chainer. """ def __init__(self, in_size, out_size=None, lateral_init=None, upward_init=None, bias_init=None, forget_bias_init=None, clip_min=None, clip_max=None): self.clip_min = clip_min self.clip_max = clip_max super(GradClipLSTM, self).__init__( in_size, out_size, lateral_init, upward_init, bias_init, forget_bias_init) def __call__(self, x): """Updates the internal state and returns the LSTM outputs. Args: x (~chainer.Variable): A new batch from the input sequence. Returns: ~chainer.Variable: Outputs of updated LSTM units. """ if self.upward.W.data is None: with cuda.get_device_from_id(self._device_id): in_size = functools.reduce(operator.mul, x.shape[1:], 1) self.upward._initialize_params(in_size) self._initialize_params() batch = x.shape[0] lstm_in = self.upward(x) h_rest = None if self.h is not None: h_size = self.h.shape[0] if batch == 0: h_rest = self.h elif h_size < batch: msg = ('The batch size of x must be equal to or less than' 'the size of the previous state h.') raise TypeError(msg) elif h_size > batch: h_update, h_rest = split_axis.split_axis( self.h, [batch], axis=0) lstm_in += self.lateral(h_update) else: lstm_in += self.lateral(self.h) if self.c is None: xp = self.xp with cuda.get_device_from_id(self._device_id): self.c = variable.Variable( xp.zeros((batch, self.state_size), dtype=x.dtype)) self.c, y = grad_clip_lstm.grad_clip_lstm( self.c, lstm_in, clip_min=self.clip_min, clip_max=self.clip_max) if h_rest is None: self.h = y elif len(y.data) == 0: self.h = h_rest else: self.h = concat.concat([y, h_rest], axis=0) return y