import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import FloatTensor
from torch.autograd import Variable
from torch.nn import Module, Parameter


class wcLinear(nn.Linear):
    """Custom Linear layer for quantization: weight columns are gated by a learnable binary mask."""

    def __init__(self, in_features, out_features, bias=True, rate=0.):
        super(wcLinear, self).__init__(in_features=in_features,
                                       out_features=out_features,
                                       bias=bias)
        # one mask entry per input feature
        self.binary_weight = Parameter(torch.ones(self.weight.data.size(1)))
        self.float_weight = Parameter(torch.ones(self.weight.data.size(1)))
        self.register_buffer('rate', torch.ones(1).fill_(rate))

    def compute_grad(self):
        # route the gradient of the binary mask to its float proxy
        self.float_weight.grad = Variable(self.binary_weight.grad.data)
        # resetting binary_weight.grad to None afterwards is essential
        self.binary_weight.grad = None

    def forward(self, input):
        if self.training:
            self.float_weight.data.clamp_(min=0)
            self.binary_weight.data.copy_(
                self.float_weight.data.ge(self.rate[0]).float())
        # get the masked weight
        new_weight = self.binary_weight.unsqueeze(0).expand_as(
            self.weight) * self.weight
        return F.linear(input, new_weight, self.bias)
class wcConv2d(nn.Conv2d):
    """Custom convolutional layer for quantization: input channels are gated by a learnable binary mask."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True, rate=0.):
        super(wcConv2d, self).__init__(in_channels=in_channels,
                                       out_channels=out_channels,
                                       kernel_size=kernel_size,
                                       stride=stride,
                                       padding=padding,
                                       bias=bias)
        # one mask entry per input channel (weight is out x in x kH x kW)
        self.binary_weight = Parameter(torch.ones(self.weight.data.size(1)))
        self.float_weight = Parameter(torch.ones(self.weight.data.size(1)))
        self.register_buffer('rate', torch.ones(1).fill_(rate))

    def compute_grad(self):
        self.float_weight.grad = Variable(self.binary_weight.grad.data)
        # resetting binary_weight.grad to None afterwards is essential
        self.binary_weight.grad = None

    def forward(self, input):
        if self.training:
            self.float_weight.data.clamp_(min=0)
            self.binary_weight.data.copy_(
                self.float_weight.data.ge(self.rate[0]).float())
        new_weight = self.binary_weight.unsqueeze(0).unsqueeze(2).unsqueeze(
            3).expand_as(self.weight) * self.weight
        return F.conv2d(input, new_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
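# Usage sketch for the masked layers above (illustrative, not from the original
# snippet; the function name and tensor shapes are assumptions). It shows that
# gradients accumulate on binary_weight during backward and are then rerouted to
# float_weight via compute_grad().
def _demo_wc_layers():
    layer = wcConv2d(3, 8, kernel_size=3, padding=1, rate=0.5)
    x = torch.randn(2, 3, 32, 32)
    y = layer(x)              # training mode: mask is recomputed from float_weight
    y.sum().backward()        # gradients land on binary_weight
    layer.compute_grad()      # copies binary_weight.grad into float_weight.grad
    return y.shape, layer.binary_weight.data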
class DOnePoleCell(Module):

    def __init__(self, a1=0.5, b0=1.0, b1=0.0):
        super(DOnePoleCell, self).__init__()
        self.b0 = Parameter(FloatTensor([b0]))
        self.b1 = Parameter(FloatTensor([b1]))
        self.a1 = Parameter(FloatTensor([a1]))

    def init_states(self, size):
        state = torch.zeros(size).to(self.a1.device)
        return state

    def forward(self, input, state):
        self.a1.data = self.a1.clamp(-1, 1)
        output = self.b0 * input + state
        state = self.b1 * input + self.a1 * output
        return output, state
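# Usage sketch for DOnePoleCell (illustrative, not from the original snippet):
# filtering a signal one sample at a time while the filter coefficients stay
# learnable. The signal length and shape below are assumptions.
def _demo_one_pole():
    cell = DOnePoleCell(a1=0.5, b0=1.0, b1=0.0)
    signal = torch.randn(100)            # mono signal, 100 samples
    state = cell.init_states(1)
    outputs = []
    for x in signal:
        y, state = cell(x, state)        # one filter step, state carried forward
        outputs.append(y)
    return torch.stack(outputs)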
class CartesianAdj(Module):
    r"""Concatenates Cartesian spatial relations based on the position
    :math:`P \in \mathbb{R}^{N \times D}` of graph nodes to the graph's edge
    attributes."""

    def __init__(self, r=None, trainable=False):
        super(CartesianAdj, self).__init__()
        if r is not None:
            r = torch.FloatTensor([r]).cuda()
        if trainable and r is not None:
            self.r = Parameter(r)
        else:
            self.r = r

    def __call__(self, data):
        row, col = data.index
        # Compute Cartesian pseudo-coordinates.
        weight = data.pos[col] - data.pos[row]
        max = weight.abs().max() if self.r is None else self.r.clamp(
            min=0.0001)
        if self.r is not None:
            weight = weight * (1 / max)
            factor = weight.abs().max(1)[0].clamp(min=1)
            weight = weight / factor.unsqueeze(1)
            weight = weight / 2
        else:
            weight = weight * (1 / (2 * max))
        weight = weight + 0.5
        if data.weight is None:
            data.weight = weight
        else:
            data.weight = torch.cat([weight, data.weight.unsqueeze(1)], dim=1)
        return data
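# Usage sketch for CartesianAdj (illustrative, not from the original snippet).
# The `data` container below is a SimpleNamespace stand-in for whatever graph
# data object this transform normally receives; only the .index, .pos and
# .weight attributes it reads are provided, and r=None keeps everything on CPU.
def _demo_cartesian_adj():
    from types import SimpleNamespace
    pos = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
    index = torch.tensor([[0, 0, 1, 2], [1, 2, 0, 0]])   # directed edges
    data = SimpleNamespace(index=index, pos=pos, weight=None)
    data = CartesianAdj()(data)
    return data.weight       # pseudo-coordinates in [0, 1], shape (E, 2)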
class BaseRNNCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=False, nonlinearity="tanh",
                 hidden_min_abs=0, hidden_max_abs=None, hidden_init=None,
                 recurrent_init=None, gradient_clip=5):
        super(BaseRNNCell, self).__init__()
        self.hidden_max_abs = hidden_max_abs
        self.hidden_min_abs = hidden_min_abs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.nonlinearity = nonlinearity
        self.hidden_init = hidden_init
        self.recurrent_init = recurrent_init
        if self.nonlinearity == "tanh":
            self.activation = F.tanh
        elif self.nonlinearity == "relu":
            self.activation = F.relu
        elif self.nonlinearity == "sigmoid":
            self.activation = F.sigmoid
        elif self.nonlinearity == "log":
            self.activation = torch.log
        elif self.nonlinearity == "sin":
            self.activation = torch.sin
        else:
            raise RuntimeError("Unknown nonlinearity: {}".format(
                self.nonlinearity))

        self.weight_ih = Parameter(torch.eye(hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(hidden_size, 20).uniform_())
        self.weight_hh1 = Parameter(torch.eye(input_size, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.randn(hidden_size))
        else:
            self.register_parameter('bias_ih', None)
        # self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    # def reset_parameters(self):
    #     for name, weight in self.named_parameters():
    #         if "bias" in name:
    #             weight.data.zero_()
    #         elif "weight_hh" in name:
    #             if self.recurrent_init is None:
    #                 nn.init.constant_(weight, 1)
    #             else:
    #                 self.recurrent_init(weight)
    #         elif "weight_ih" in name:
    #             if self.hidden_init is None:
    #                 nn.init.normal_(weight, 0, 0.01)
    #             else:
    #                 self.hidden_init(weight)
    #         else:
    #             weight.data.normal_(0, 0.01)
    #             # weight.data.uniform_(-stdv, stdv)
    #     self.check_bounds()

    def check_bounds(self):
        if self.hidden_min_abs:
            abs_kernel = torch.abs(
                self.weight_hh.data).clamp_(min=self.hidden_min_abs)
            self.weight_hh.data = torch.sign(self.weight_hh.data) * abs_kernel
        if self.hidden_max_abs:
            self.weight_hh.data = self.weight_hh.clamp(
                max=self.hidden_max_abs, min=-self.hidden_max_abs)

    def forward(self, input, hx):
        # x = F.linear(input, self.weight_ih, self.bias_ih) + torch.matmul(hx, self.weight_hh.matmul(self.weight_hh1))
        # return self.taylor(x)
        return self.activation(
            F.linear(input, self.weight_ih, self.bias_ih) +
            torch.matmul(hx, self.weight_ih.matmul(self.weight_hh1)))

    def taylor(self, x):
        # third-order Taylor expansion of log(x) around x = 1
        return (x - 1) - (x - 1) * (x - 1) / 2 + (x - 1) * (x - 1) * (x - 1) / 3
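# Usage sketch for BaseRNNCell (illustrative, not from the original snippet):
# a single recurrent step on random data. Batch size, feature sizes and the
# zero initial hidden state are assumptions.
def _demo_base_rnn_cell():
    cell = BaseRNNCell(input_size=10, hidden_size=20, nonlinearity="relu")
    x = torch.randn(4, 10)
    hx = torch.zeros(4, 20)
    return cell(x, hx)       # next hidden state, shape (4, 20)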
class IndRNNCell(nn.Module):
    r"""An IndRNN cell with tanh or ReLU non-linearity.

    .. math::

        h' = \tanh(w_{ih} * x + b_{ih} + w_{hh} (*) h)

    With (*) being element-wise vector multiplication.
    If nonlinearity='relu', then ReLU is used in place of tanh.

    Args:
        input_size: The number of expected features in the input x
        hidden_size: The number of features in the hidden state h
        bias: If ``False``, then the layer does not use bias weights b_ih and
            b_hh. Default: ``True``
        nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'relu'
        hidden_min_abs: Minimal absolute initial value for hidden weights. Default: 0
        hidden_max_abs: Maximal absolute initial value for hidden weights. Default: None

    Inputs: input, hidden
        - **input** (batch, input_size): tensor containing input features
        - **hidden** (batch, hidden_size): tensor containing the initial hidden
          state for each element in the batch.

    Outputs: h'
        - **h'** (batch, hidden_size): tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(input_size x hidden_size)`
        weight_hh: the learnable hidden-hidden weights, of shape `(hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`

    Examples::

        >>> rnn = nn.IndRNNCell(10, 20)
        >>> input = Variable(torch.randn(6, 3, 10))
        >>> hx = Variable(torch.randn(3, 20))
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="relu",
                 hidden_min_abs=0, hidden_max_abs=None, hidden_init=None,
                 recurrent_init=None, gradient_clip=None):
        super(IndRNNCell, self).__init__()
        self.hidden_max_abs = hidden_max_abs
        self.hidden_min_abs = hidden_min_abs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.nonlinearity = nonlinearity
        self.hidden_init = hidden_init
        self.recurrent_init = recurrent_init
        if self.nonlinearity == "tanh":
            self.activation = F.tanh
        elif self.nonlinearity == "relu":
            self.activation = F.relu
        else:
            raise RuntimeError("Unknown nonlinearity: {}".format(
                self.nonlinearity))

        self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.Tensor(hidden_size))
        else:
            self.register_parameter('bias_ih', None)

        if gradient_clip:
            if isinstance(gradient_clip, tuple):
                min_g, max_g = gradient_clip
            else:
                max_g = gradient_clip
                min_g = -max_g
            self.weight_ih.register_hook(
                lambda x: x.clamp(min=min_g, max=max_g))
            self.weight_hh.register_hook(
                lambda x: x.clamp(min=min_g, max=max_g))
            if bias:
                self.bias_ih.register_hook(
                    lambda x: x.clamp(min=min_g, max=max_g))

        self.reset_parameters()

    def reset_parameters(self):
        for name, weight in self.named_parameters():
            if "bias" in name:
                weight.data.zero_()
            elif "weight_hh" in name:
                if self.recurrent_init is None:
                    nn.init.constant_(weight, 1)
                else:
                    self.recurrent_init(weight)
            elif "weight_ih" in name:
                if self.hidden_init is None:
                    nn.init.normal_(weight, 0, 0.01)
                else:
                    self.hidden_init(weight)
            else:
                weight.data.normal_(0, 0.01)
                # weight.data.uniform_(-stdv, stdv)
        self.check_bounds()

    def check_bounds(self):
        if self.hidden_min_abs:
            abs_kernel = torch.abs(
                self.weight_hh.data).clamp_(min=self.hidden_min_abs)
            self.weight_hh.data = torch.sign(self.weight_hh.data) * abs_kernel
        if self.hidden_max_abs:
            self.weight_hh.data = self.weight_hh.clamp(
                max=self.hidden_max_abs, min=-self.hidden_max_abs)

    def forward(self, input, hx):
        return self.activation(
            F.linear(input, self.weight_ih, self.bias_ih) +
            self.weight_hh * hx)
class IndRNNCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True, activation="relu",
                 recurrent_min_abs=None, recurrent_max_abs=None,
                 hidden_initializer=None, recurrent_initializer=None,
                 gradient_clip_min=None, gradient_clip_max=None):
        super(IndRNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_ih = Parameter(
            torch.Tensor(self.hidden_size, self.input_size))
        self.weight_hh = Parameter(torch.Tensor(hidden_size))
        if bias:
            self.bias_ih = nn.Parameter(torch.Tensor(hidden_size))
        else:
            self.register_parameter('bias_ih', None)

        if activation == "relu":
            self.activation = F.relu
        elif activation == "tanh":
            self.activation = F.tanh
        else:
            warnings.warn(
                "IndRNN supports only ReLU and tanh activations. Falling back to ReLU."
            )
            self.activation = F.relu

        self.recurrent_min_abs = recurrent_min_abs
        self.recurrent_max_abs = recurrent_max_abs
        self.hidden_initializer = hidden_initializer
        self.recurrent_initializer = recurrent_initializer

        # Gradient clipping to prevent gradient explosion and overfitting
        if gradient_clip_max is not None:
            self.gradient_clip_min = -gradient_clip_max
            self.gradient_clip_max = gradient_clip_max
            if gradient_clip_min is not None:
                self.gradient_clip_min = gradient_clip_min
            # register_hook clips the gradient of each parameter; the clipped
            # gradient is what gradient descent then uses
            self.weight_ih.register_hook(lambda x: x.clamp_(
                min=gradient_clip_min, max=gradient_clip_max))
            self.weight_hh.register_hook(lambda x: x.clamp_(
                min=gradient_clip_min, max=gradient_clip_max))
            if self.bias:
                self.bias_ih.register_hook(lambda x: x.clamp_(
                    min=gradient_clip_min, max=gradient_clip_max))

        # Initialize all parameters of the model
        for name, weight in self.named_parameters():
            if "bias" in name:
                # self.add_variable("bias", shape=[self._num_units], initializer=init_ops.zeros_initializer(dtype=self.dtype))
                weight.data.zero_()
            elif "weight_ih" in name:
                # self._input_initializer = init_ops.random_normal_initializer(mean=0.0, stddev=0.001)
                if self.hidden_initializer is None:
                    nn.init.normal_(weight, 0, 0.01)
                else:
                    self.hidden_initializer(weight)
            elif "weight_hh" in name:
                # self._recurrent_initializer = init_ops.constant_initializer(1.)
                if self.recurrent_initializer is None:
                    nn.init.constant_(weight, 1)
                else:
                    self.recurrent_initializer(weight)
            else:
                weight.data.normal_(0, 0.01)

        self.clip_recurrent_weights()

    def clip_recurrent_weights(self):
        r"""Clip the absolute values of the recurrent weights to the specified
        minimum and maximum.

        Reference (TensorFlow) implementation from
        https://github.com/batzner/indrnn/blob/master/ind_rnn_cell.py:

            # Clip the absolute values of the recurrent weights to the specified minimum
            if self._recurrent_min_abs:
                abs_kernel = math_ops.abs(self._recurrent_kernel)
                min_abs_kernel = math_ops.maximum(abs_kernel, self._recurrent_min_abs)
                self._recurrent_kernel = math_ops.multiply(
                    math_ops.sign(self._recurrent_kernel), min_abs_kernel
                )
            # Clip the absolute values of the recurrent weights to the specified maximum
            if self._recurrent_max_abs:
                self._recurrent_kernel = clip_ops.clip_by_value(self._recurrent_kernel,
                                                                -self._recurrent_max_abs,
                                                                self._recurrent_max_abs)
        """
        if self.recurrent_min_abs:
            abs_kernel = torch.abs(
                self.weight_hh.data).clamp_(min=self.recurrent_min_abs)
            self.weight_hh.data = torch.sign(self.weight_hh.data) * abs_kernel
        if self.recurrent_max_abs:
            self.weight_hh.data = self.weight_hh.clamp(
                max=self.recurrent_max_abs, min=-self.recurrent_max_abs)

    # Pending: implement code for dropouts
    # --------

    def forward(self, input, hx=None):
        # out = act(w_{ih} * x + b_{ih} + w_{hh} (*) h), with (*) the Hadamard product
        return self.activation(
            F.linear(input, self.weight_ih, self.bias_ih) +
            self.weight_hh * hx)
class IndRNNCell(nn.Module):
    r"""IndRNN Cell computes:

    $$h_t = \sigma(w_{ih} x_t + b_{ih} + w_{hh} (*) h_{t-1})$$

    where \sigma is the chosen nonlinearity.

    hyper-params:
        1. hidden_size
        2. input_size
        3. bias: true or false
        4. act: the nonlinearity function ("tanh", "relu", "sigmoid")
        5. hidden_min_abs & hidden_max_abs
        6. reccurent_only: only computes the recurrent part for faster computation.
        7. init: how to initialize the params. Default norm for N(0, 1/sqrt(size));
           constant; uniform; orth
        8. gradient_clip: `(min, max)` or `bound`

    inputs:
        1. Input: (batch, input_size)
        2. Hidden: (batch, hidden_size)
        batch first by default

    outputs:
        1. output: (batch, hidden_size)
        2. hidden state: (batch, hidden_size)

    params:
        1. weight_ih: (hidden_size, input_size)
        2. weight_hh: (1, hidden_size)
        3. bias_ih: (1, hidden_size) or None

    usage:
        >>> cell = IndRNNCell(100, 128)
        >>> Input = torch.randn(32, 100)
        >>> Hidden = torch.randn(32, 128)
        >>> _, h = cell(Input, Hidden)
    """

    def __init__(self, input_size, hidden_size, bias=True, act="relu",
                 hidden_min_abs=0, hidden_max_abs=2, reccurent_only=False,
                 gradient_clip=None, init_ih="norm",
                 input_weight_initializer=None,
                 recurrent_weight_initializer=None, name="Default",
                 debug=False):
        super(IndRNNCell, self).__init__()
        self.hidden_max_abs = hidden_max_abs
        self.hidden_min_abs = hidden_min_abs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.act = act
        self.reccurent_only = reccurent_only
        self.init_ih = init_ih
        self.input_weight_initializer = input_weight_initializer
        self.recurrent_weight_initializer = recurrent_weight_initializer
        self.name = name
        self.debug = debug
        if self.act is None:
            self.activation = None
        elif self.act == "relu":
            self.activation = F.relu
        elif self.act == "sigmoid":
            self.activation = F.sigmoid
        elif self.act == "tanh":
            self.activation = F.tanh
        else:
            raise RuntimeError(f"Unknown activation type: {self.act}")

        if not self.reccurent_only:
            self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))
        else:
            self.register_parameter('weight_ih', None)
        self.weight_hh = Parameter(torch.Tensor(1, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.Tensor(1, hidden_size))
        else:
            self.register_parameter('bias_ih', None)

        if gradient_clip:
            if isinstance(gradient_clip, tuple):
                assert len(gradient_clip) == 2
                min_g, max_g = gradient_clip
            else:
                max_g = gradient_clip
                min_g = -max_g
            if not self.reccurent_only:
                self.weight_ih.register_hook(
                    lambda x: x.clamp(min=min_g, max=max_g))
            self.weight_hh.register_hook(
                lambda x: x.clamp(min=min_g, max=max_g))
            if bias:
                self.bias_ih.register_hook(
                    lambda x: x.clamp(min=min_g, max=max_g))

        # debug
        # if self.debug:
        #     pdb.set_trace()

        self.reset_parameters()

    def reset_parameters(self):
        for name, weight in self.named_parameters():
            if "bias" in name:
                weight.data.zero_()
            elif "weight" in name:
                if self.input_weight_initializer and "weight_ih" in name:
                    self.input_weight_initializer(weight)
                elif self.recurrent_weight_initializer and "weight_hh" in name:
                    self.recurrent_weight_initializer(weight)
                elif "constant" in self.init_ih:
                    nn.init.constant_(weight, 1.0)
                else:
                    weight.data.normal_(0, 0.01)
        self.clip_weight()

    def clip_weight(self):
        if self.hidden_min_abs:
            abs_kernel = torch.abs(
                self.weight_hh.data).clamp(min=self.hidden_min_abs)
            self.weight_hh.data = torch.sign(self.weight_hh.data) * abs_kernel
        if self.hidden_max_abs:
            self.weight_hh.data = self.weight_hh.clamp(
                min=-self.hidden_max_abs, max=self.hidden_max_abs)
        self.weight_hh.data.detach_()

    def forward(self, Input, Hidden):
        if not self.reccurent_only:
            h = F.linear(Input, self.weight_ih) + self.weight_hh * Hidden
            if self.bias:
                h += self.bias_ih
        else:
            h = Input + self.weight_hh * Hidden
        if self.activation:
            h = self.activation(h)
        return h, h
class SBP(Gate):

    def __init__(self, num_gates, min_log=-20.0, max_log=0.0, thres=1.0,
                 kl_scale=1.0):
        super(SBP, self).__init__(num_gates)
        self.min_log = min_log
        self.max_log = max_log
        self.thres = thres
        self.kl_scale = kl_scale
        self.mu = Parameter(torch.zeros(num_gates))
        self.log_sigma = Parameter(-5 * torch.ones(num_gates))

    def _mean_truncated_log_normal(self):
        a, b = self.min_log, self.max_log
        mu = self.mu.clamp(-20.0, 5.0)
        log_sigma = self.log_sigma.clamp(-20.0, 5.0)
        sigma = log_sigma.exp()
        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        z = phi(beta) - phi(alpha)
        mean = erfcx(
            (sigma - beta) / math.sqrt(2.0)) * torch.exp(b - beta * beta / 2)
        mean = mean - erfcx((sigma - alpha) / math.sqrt(2.0)) * torch.exp(
            a - alpha * alpha / 2)
        mean = mean / (2 * z)
        return mean

    def _snr_truncated_log_normal(self):
        a, b = self.min_log, self.max_log
        mu = self.mu.clamp(-20.0, 5.0)
        log_sigma = self.log_sigma.clamp(-20.0, 5.0)
        sigma = log_sigma.exp()
        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        z = phi(beta) - phi(alpha)
        ratio = erfcx((sigma - beta) / math.sqrt(2.0)) * torch.exp(
            (b - mu) - beta**2 / 2.0)
        ratio = ratio - erfcx((sigma - alpha) / math.sqrt(2.0)) * torch.exp(
            (a - mu) - alpha**2 / 2.0)
        denominator = 2 * z * erfcx(
            (2.0 * sigma - beta) / math.sqrt(2.0)) * torch.exp(
                2.0 * (b - mu) - beta**2 / 2.0)
        denominator = denominator - 2 * z * erfcx(
            (2.0 * sigma - alpha) / math.sqrt(2.0)) * torch.exp(
                2.0 * (a - mu) - alpha**2 / 2.0)
        denominator = denominator - ratio**2
        ratio = ratio / torch.sqrt(denominator)
        return ratio

    def _sample_truncated_normal(self):
        a, b = self.min_log, self.max_log
        mu = self.mu.clamp(-20.0, 5.0)
        log_sigma = self.log_sigma.clamp(-20.0, 5.0)
        sigma = torch.exp(log_sigma)
        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        u = torch.rand(self.num_gates)
        if torch.cuda.is_available():
            u = u.cuda()
        gamma = phi(alpha) + u * (phi(beta) - phi(alpha))
        return (phi_inv(gamma.clamp(1e-5, 1 - 1e-5)) * sigma + mu).clamp(
            a, b).exp()

    def get_mask(self):
        snr = self._snr_truncated_log_normal()
        return (snr > self.thres).float()

    def get_weight(self, x):
        if self.training:
            z = self._sample_truncated_normal()
        else:
            Etheta = self._mean_truncated_log_normal()
            mask = self.get_mask()
            z = Etheta * mask
        return z

    def get_reg(self, base):
        a, b = self.min_log, self.max_log
        mu = self.mu.clamp(-20.0, 5.0)
        log_sigma = self.log_sigma.clamp(-20.0, 5.0)
        sigma = log_sigma.exp()
        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        z = phi(beta) - phi(alpha)

        def pdf(x):
            return torch.exp(-x * x / 2.0) / math.sqrt(2.0 * math.pi)

        kld = -log_sigma - torch.log(z) - (alpha * pdf(alpha) -
                                           beta * pdf(beta)) / (2.0 * z)
        kld += math.log(self.max_log - self.min_log) - math.log(
            2.0 * math.pi * math.e) / 2.0
        kld = self.kl_scale * kld.sum()
        return kld
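# Usage sketch for the SBP gate (illustrative, not from the original snippet).
# It assumes the Gate base class is an nn.Module that stores num_gates, and that
# the helpers phi, phi_inv and erfcx used above are available in this module;
# the gate count and calls below are assumptions.
def _demo_sbp_gate():
    gate = SBP(num_gates=16)
    gate.eval()
    mask = gate.get_mask()      # hard 0/1 gates from the signal-to-noise ratio
    z = gate.get_weight(None)   # E[theta] * mask at eval time (stochastic sample in train mode)
    kld = gate.get_reg(None)    # KL term to add to the training loss, scaled by kl_scale
    return mask, z, kld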