def test_sigmoid_grad(N=None):
    """
    Check `Sigmoid.grad` against a PyTorch-derived reference gradient.

    Each trial draws a random 2-D input whose two dimensions are each sampled
    from [1, 100). It compares the project's `Sigmoid().grad` with the output
    of `torch_gradient_generator(torch.sigmoid)`, which presumably computes
    the elementwise derivative via autograd (TODO confirm). "PASSED" is
    printed after every successful trial.

    Parameters
    ----------
    N : int or None (default: None)
        Number of trials to run. If None, trials run indefinitely.
    """
    from activations import Sigmoid

    n_trials = N if N is not None else np.inf
    layer = Sigmoid()
    reference_grad = torch_gradient_generator(torch.sigmoid)

    trial = 0
    while trial < n_trials:
        # draw both dimensions in the same order as before so RNG streams match
        rows = np.random.randint(1, 100)
        cols = np.random.randint(1, 100)
        inputs = random_tensor((rows, cols))

        assert_almost_equal(layer.grad(inputs), reference_grad(inputs))
        print("PASSED")
        trial += 1
class WavenetResidualModule(ModuleBase):
    def __init__(
        self,
        ch_residual,
        ch_dilation,
        dilation,
        kernel_width,
        optimizer=None,
        init="glorot_uniform",
    ):
        """
        A WaveNet-like residual block with causal dilated convolutions.

            *Skip path in*
                >-------------------------------------------------> + ---> *Skip path out*
                              Causal    |--> Tanh --|                 |
            *Main    |--> Dilated Conv1D --|          * --> 1x1 Conv1D --|
             path >--|                     |--> Sigm --|                 |
             in*     |------------------------------------------------------> + ---> *Main path out*
                                     *Residual path*

        On the final block, the output of the skip path is further processed
        to produce the network predictions.

        See van den Oord et al. (2016) at https://arxiv.org/pdf/1609.03499.pdf
        for further details.

        Parameters
        ----------
        ch_residual : int
            The number of output channels for the 1x1 Conv1D layer in the
            main path.
        ch_dilation : int
            The number of output channels for the causal dilated Conv1D layer
            in the main path.
        dilation : int
            The dilation rate for the causal dilated Conv1D layer in the main
            path.
        kernel_width : int
            The width of the causal dilated Conv1D kernel in the main path.
        optimizer : str or `OptimizerBase` instance (default: None)
            The optimization strategy to use when performing gradient updates
            within the `update` method. If `None`, use the `SGD` optimizer
            with default parameters.
        init : str (default: 'glorot_uniform')
            The weight initialization strategy. Valid entries are
            {'glorot_normal', 'glorot_uniform', 'he_normal', 'he_uniform'}.
        """
        super().__init__()

        self.init = init
        self.dilation = dilation
        self.optimizer = optimizer
        self.ch_residual = ch_residual
        self.ch_dilation = ch_dilation
        self.kernel_width = kernel_width

        self._init_params()

    def _init_params(self):
        """Build the component layers and reset the derived-variable cache."""
        self._dv = {}

        # Causal dilated convolution on the main path. Uses the caller's
        # `kernel_width` so the layer agrees with what `hyperparameters` reports.
        self.conv_dilation = Conv1D(
            stride=1,
            pad="causal",
            init=self.init,
            kernel_width=self.kernel_width,
            dilation=self.dilation,
            out_ch=self.ch_dilation,
            optimizer=self.optimizer,
            act_fn=Affine(slope=1, intercept=0),
        )

        self.tanh = Tanh()
        self.sigm = Sigmoid()
        self.multiply_gate = Multiply(act_fn=Affine(slope=1, intercept=0))

        # NOTE(review): dilation=0 presumably means "undilated" under this
        # project's Conv1D convention -- TODO confirm.
        self.conv_1x1 = Conv1D(
            stride=1,
            pad="same",
            dilation=0,
            init=self.init,
            kernel_width=1,
            out_ch=self.ch_residual,
            optimizer=self.optimizer,
            act_fn=Affine(slope=1, intercept=0),
        )

        # Identity-activated adders for the residual and skip connections.
        self.add_residual = Add(act_fn=Affine(slope=1, intercept=0))
        self.add_skip = Add(act_fn=Affine(slope=1, intercept=0))

    @property
    def parameters(self):
        """Parameters of each trainable/stateful component, keyed by name."""
        return {
            "components": {
                "conv_1x1": self.conv_1x1.parameters,
                "add_skip": self.add_skip.parameters,
                "add_residual": self.add_residual.parameters,
                "conv_dilation": self.conv_dilation.parameters,
                "multiply_gate": self.multiply_gate.parameters,
            }
        }

    @property
    def hyperparameters(self):
        """Module hyperparameters plus those of every component layer."""
        return {
            "layer": "WavenetResidualModule",
            "init": self.init,
            "dilation": self.dilation,
            "optimizer": self.optimizer,
            "ch_residual": self.ch_residual,
            "ch_dilation": self.ch_dilation,
            "kernel_width": self.kernel_width,
            "component_ids": [
                "conv_1x1",
                "add_skip",
                "add_residual",
                "conv_dilation",
                "multiply_gate",
            ],
            "components": {
                "conv_1x1": self.conv_1x1.hyperparameters,
                "add_skip": self.add_skip.hyperparameters,
                "add_residual": self.add_residual.hyperparameters,
                "conv_dilation": self.conv_dilation.hyperparameters,
                "multiply_gate": self.multiply_gate.hyperparameters,
            },
        }

    @property
    def derived_variables(self):
        """
        Intermediate values cached by `forward`/`backward`, plus the derived
        variables of each component. Keys not yet computed default to None.
        """
        dv = {
            "conv_1x1_out": None,
            "conv_dilation_out": None,
            "multiply_gate_out": None,
            "components": {
                "conv_1x1": self.conv_1x1.derived_variables,
                "add_skip": self.add_skip.derived_variables,
                "add_residual": self.add_residual.derived_variables,
                "conv_dilation": self.conv_dilation.derived_variables,
                "multiply_gate": self.multiply_gate.derived_variables,
            },
        }
        dv.update(self._dv)
        return dv

    @property
    def gradients(self):
        """Gradients accumulated by each component, keyed by name."""
        return {
            "components": {
                "conv_1x1": self.conv_1x1.gradients,
                "add_skip": self.add_skip.gradients,
                "add_residual": self.add_residual.gradients,
                "conv_dilation": self.conv_dilation.gradients,
                "multiply_gate": self.multiply_gate.gradients,
            }
        }

    def forward(self, X_main, X_skip=None):
        """
        Compute the main-path and skip-path outputs of the block.

        Parameters
        ----------
        X_main : ndarray
            Input to the main path. It is fed to the dilated Conv1D and also
            summed with the 1x1 conv output, so its shape must match that
            output (presumably (n_ex, l_in, ch_residual) -- TODO confirm).
        X_skip : ndarray or None (default: None)
            Running skip-path sum from previous blocks. If None (first
            block), a zero array shaped like the 1x1 conv output is used.

        Returns
        -------
        Y_main : ndarray
            `X_main` plus the 1x1 conv output (residual connection).
        Y_skip : ndarray
            `X_skip` plus the 1x1 conv output.
        """
        self.X_main = X_main
        conv_dilation_out = self.conv_dilation.forward(X_main)

        # Gated activation: tanh(z) * sigmoid(z) on the same pre-activation.
        tanh_gate = self.tanh.fn(conv_dilation_out)
        sigm_gate = self.sigm.fn(conv_dilation_out)

        multiply_gate_out = self.multiply_gate.forward([tanh_gate, sigm_gate])
        conv_1x1_out = self.conv_1x1.forward(multiply_gate_out)

        # If this is the first wavenet block, initialize the "previous" skip
        # connection sum to 0. The substituted array (not the raw None) must
        # be what the skip adder receives.
        if X_skip is None:
            X_skip = np.zeros_like(conv_1x1_out)
        self.X_skip = X_skip

        Y_skip = self.add_skip.forward([X_skip, conv_1x1_out])
        Y_main = self.add_residual.forward([X_main, conv_1x1_out])

        self._dv["tanh_out"] = tanh_gate
        self._dv["sigm_out"] = sigm_gate
        self._dv["conv_dilation_out"] = conv_dilation_out
        self._dv["multiply_gate_out"] = multiply_gate_out
        self._dv["conv_1x1_out"] = conv_1x1_out
        return Y_main, Y_skip

    def backward(self, dY_skip, dY_main=None):
        """
        Backpropagate through the block. `forward` must have been called first.

        Parameters
        ----------
        dY_skip : ndarray
            Gradient of the loss wrt. the skip-path output `Y_skip`.
        dY_main : ndarray or None (default: None)
            Gradient of the loss wrt. the main-path output `Y_main`. Pass None
            for the last block, whose main-path output is unused.

        Returns
        -------
        dX_main : ndarray
            Gradient of the loss wrt. `X_main`.
        dX_skip : ndarray
            Gradient of the loss wrt. `X_skip`.
        """
        dX_skip, dConv_1x1_out = self.add_skip.backward(dY_skip)

        # If this is the last wavenet block, dY_main will be None. If not,
        # calculate the error contribution from dY_main and add it to the
        # contribution from the skip path.
        dX_main = np.zeros_like(self.X_main)
        if dY_main is not None:
            dX_main, dConv_1x1_main = self.add_residual.backward(dY_main)
            # Out-of-place so we never mutate an array returned by
            # add_skip.backward (it could alias dX_skip).
            dConv_1x1_out = dConv_1x1_out + dConv_1x1_main

        dMultiply_out = self.conv_1x1.backward(dConv_1x1_out)
        dTanh_out, dSigm_out = self.multiply_gate.backward(dMultiply_out)

        # Both gates read the same pre-activation, so their input gradients sum.
        conv_dilation_out = self._dv["conv_dilation_out"]
        dTanh_in = dTanh_out * self.tanh.grad(conv_dilation_out)
        dSigm_in = dSigm_out * self.sigm.grad(conv_dilation_out)
        dDilation_out = dTanh_in + dSigm_in

        # Main-path input gradient = residual contribution + conv contribution.
        conv_back = self.conv_dilation.backward(dDilation_out)
        dX_main = dX_main + conv_back

        self._dv["dLdTanh"] = dTanh_out
        self._dv["dLdSigmoid"] = dSigm_out
        self._dv["dLdConv_1x1"] = dConv_1x1_out
        self._dv["dLdMultiply"] = dMultiply_out
        self._dv["dLdConv_dilation"] = dDilation_out
        return dX_main, dX_skip