Example #1
def fn(x: bool, y: torch.Tensor):
    if x:
        y.add_(8)
        a = y + 3
    else:
        a = y + 3
    return a
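A minimal usage sketch (values are made up, assuming fn from the snippet above is in scope): when x is True, add_(8) mutates the caller's tensor in place before a is computed.

import torch

y = torch.zeros(3)
out = fn(True, y)
print(y)    # tensor([8., 8., 8.])  -- y was modified in place by add_
print(out)  # tensor([11., 11., 11.])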
Example #2
def _sp_double_backward_update(pos_out: Tensor,
                               neg_out: Tensor,
                               param: Parameter,
                               gamma: float,
                               l1_reg: float,
                               l2_reg: float,
                               pos: Tensor = None):
    param.grad = None
    # first backward
    neg_out.backward()
    neg = param.grad.relu_().add_(eps)  # eps: small positive constant assumed to be defined at module level

    if pos is None:
        param.grad = None
        pos_out.backward()
        pos = param.grad.relu_().add_(eps)

    if l1_reg > 0:
        pos.add_(l1_reg)
    if l2_reg > 0:
        pos = pos.add(param.data, alpha=l2_reg)
    multiplier = neg.div_(pos)
    if gamma != 1:
        multiplier.pow_(gamma)
    param.data.mul_(multiplier)
Example #3
class Linear(Module):
    """
    Layer module: Fully connected layer defined by input dimensions and output_dimensions
    
    Outputs:
    forward  :  FloatTensor of size m (m: number of units)
    backward :  FloatTensor of size m (m: number of units)
    """
    def __init__(self, input_dim, output_dim, epsilon=1):
        super().__init__()
        torch.manual_seed(123)
        self.weight = Tensor(output_dim, input_dim).normal_(mean=0, std=epsilon)
        self.bias = Tensor(output_dim).normal_(0, epsilon)
        self.x = 0
        self.dl_dw = Tensor(self.weight.size())
        self.dl_db = Tensor(self.bias.size())
         
    def forward(self, input):
        self.x = input
        return self.weight.mv(self.x) + self.bias
    
    def backward(self, grdwrtoutput):
        self.dl_dw.add_(grdwrtoutput.view(-1,1).mm(self.x.view(1,-1)))
        self.dl_db.add_(grdwrtoutput)
        return self.weight.t().mv(grdwrtoutput)
    
    def param(self):
        return [(self.weight, self.dl_dw), (self.bias, self.dl_db)]
Example #4
def _norm_image(img: torch.Tensor
                ) -> torch.Tensor:
    img = img.clone()
    min, max = img.min().item(), img.max().item()
    img.clamp_(min=min, max=max)
    img.add_(-min).div_(max - min + 1e-5)
    return img
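A short usage sketch of the min-max normalisation above, with made-up input values; the clone() keeps the caller's tensor untouched.

import torch

img = torch.tensor([[-2.0, 0.0], [2.0, 6.0]])
out = _norm_image(img)
print(out.min().item())  # 0.0
print(out.max().item())  # just under 1.0, because of the +1e-5 in the denominator
print(img.min().item())  # -2.0 -- the original tensor is unchanged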
Example #5
class Linear(Module):
    """
    Linear layer implementation
    """
    def __init__(self, input_dim, output_dim, bias=True):

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = bias

        # init layer weights using xavier initialization
        self.weights = Tensor(output_dim,
                              input_dim).normal_(mean=0,
                                                 std=1 / self.input_dim)
        # a zero-initialised bias tensor is always allocated so that the
        # dl_db accumulator below has a well-defined size
        self.bias = Tensor(output_dim).zero_()

        # initialize the tensors to accumulate the gradients during backprop
        # remember to initialize to zero at the beginning of every mini-batch step
        self.dl_dw = Tensor(self.weights.size()).zero_()
        self.dl_db = Tensor(self.bias.size()).zero_()

    def forward(self, *input):
        x = input[0]

        # saving the input of the previous layer (x_{l-1}) for the backprop algorithm
        self.input_prec_layer = input[0]

        # saving the output of this layer (s_{l}) for the backprop algorithm
        self.output_non_activated = self.weights.mv(x) + self.bias

        return self.output_non_activated

    def backward(self, *gradwrtoutput):

        # for a linear layer l, the gradwrtoutput will be the grad output
        # from the activation module, that is the product of dsigma(s_{l})
        # and the grad wrt the output of the activation function
        grad_wrt_s_l = gradwrtoutput[0]

        # compute the grad wrt the input of previous layer (x_{l-1})
        grad_wrt_input_prev_layer = self.weights.t().mv(grad_wrt_s_l)

        # compute the grad wrt the weights of this layer
        # accumulate the grad in our specific tensor
        self.dl_dw.add_(
            grad_wrt_s_l.view(-1, 1).mm(self.input_prec_layer.view(1, -1)))

        # compute grad wrt the bias term
        self.dl_db.add_(grad_wrt_s_l)

        return grad_wrt_input_prev_layer

    def param(self):
        """
        returns pair of tensors: first is a parameter tensor,
        the second is the gradient accumulator for this parameter tensor
        :return:
        """
        return [(self.weights, self.dl_dw), (self.bias, self.dl_db)]
Example #6
def add_fourier_noise(
    idx: Tuple[int, int],
    images: Tensor,
    norm: float,
    size: Optional[Tuple[int, int]] = None,
) -> Tensor:
    """ Add Fourier noise

    Args:
        idx: index to be used
        images: original images
        norm: norm of additive noise
        size: optional spatial size (h, w) of the Fourier basis; defaults to the spatial size of images

    Returns: images with Fourier noise

    """

    images = images.clone()

    if size is None:
        _, _, h, w = images.size()
    else:
        h, w = size

    noise = images.new_zeros(1, h, w, 2)
    noise[:, idx[0], idx[1]] = 1
    noise[:, h - 1 - idx[0], w - 1 - idx[1]] = 1
    recon = _irfft(ifft_shift(noise), 2, normalized=True,
                   onesided=False).unsqueeze(0)
    recon.div_(recon.norm(p=2)).mul_(norm)
    if size is not None:
        recon = F.interpolate(recon, images.shape[2:])
    images.add_(recon).clamp_(0, 1)
    return images
Example #7
    def _finalize(self, A: torch.Tensor, d: DistKerContainer) -> torch.Tensor:
        A.mul_(-2.0)
        A.add_(d.sq1.to(A))
        A.add_(d.sq2.to(A).t())
        A.clamp_min_(0)
        A.sqrt_()
        return self._transform(A)
Example #8
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Propagate the input through the Encoder block.

        Apply the Multi Head Attention block, add residual and normalize.
        Apply the Point-wise Feed Forward block, add residual and normalize.

        Parameters
        ----------
        x:
            Input tensor with shape (batch_size, K, d_model).

        Returns
        -------
            Output tensor with shape (batch_size, K, d_model).
        """
        # Self attention
        residual = x
        x = self._selfAttention(query=x, key=x, value=x)
        x = self._dopout(x)
        x.add_(residual)
        x = self._layerNorm1(x)

        # Feed forward
        residual = x
        x = self._feedForward(x)
        x = self._dopout(x)
        x.add_(residual)
        x = self._layerNorm2(x)

        return x
Example #9
    def step(self, log_probs: torch.Tensor) -> torch.Tensor:
        """Take a single search step.

        Args:
            log_probs: (batch_size, vocab_size)
                the model's log-probabilities over the vocabulary at the current step

        Return:
            beams: (batch_size,)
                the hypothesis ids of the chosen elements, in the range [0, batch_size)
        """
        super()._step_check(log_probs)
        if self._scores is None:
            assert self._hypotheses is None
            self._init_state(log_probs.dtype, log_probs.device)

        log_probs.add_(self._scores.unsqueeze(1))
        log_probs = log_probs.flatten()
        sample_scores, samples = torch.topk(
            log_probs,
            # Take more to ensure that we will keep search_size not terminated
            min((1 + len(self._eos_ids)) * self._search_size,
                log_probs.size(0)),
            sorted=False,
        )
        sort_mask = torch.div(samples, self._vocab_size)
        samples.fmod_(self._vocab_size)

        self._init_sort_mask()
        self._update_state(samples, sample_scores, sort_mask)
        self._length += 1

        return self._sort_mask
Example #10
def add_gaussian_noise(tensor: torch.Tensor, batch_size, sigma, clip_bound):
    """add noise to a list tensors"""
    noise_to_add = torch.zeros(tensor.shape, requires_grad=False).cuda()
    noise_to_add.normal_(0., std=clip_bound * sigma)
    noise = noise_to_add / float(batch_size)
    with torch.no_grad():
        tensor.add_(noise)
Example #11
    def _compute_deltas(self, cluster_centers: torch.Tensor,
                        boosting_targets: torch.Tensor,
                        cluster_center_deltas: torch.Tensor,
                        cluster_center_targets: torch.Tensor,
                        prev_boosted_clusters: torch.Tensor,
                        _boost_deltas: torch.Tensor):
        """Compute the deltas - how much should the cluster centers move."""
        # Compute deltas for the non-boosted training
        torch.add(cluster_center_targets,
                  alpha=-1.,
                  other=cluster_centers,
                  out=cluster_center_deltas)

        # Set NaNs to 0
        cluster_center_deltas.masked_fill_(torch.isnan(cluster_center_deltas),
                                           0)

        # Create indexes: so that we can for each cluster center extract its target for boosting.
        # It is done even for those which should not be boosted, but they are then zeroed.
        boosting_targets_indexes = boosting_targets.unsqueeze(dim=2).expand(
            self._flock_size, self._n_cluster_centers, self._input_size)

        # Compute deltas for the boosting
        boosting_cluster_centers = torch.gather(cluster_centers,
                                                dim=1,
                                                index=boosting_targets_indexes)
        torch.add(boosting_cluster_centers,
                  alpha=-1.,
                  other=cluster_centers,
                  out=_boost_deltas)

        # Set boosting deltas to zero for clusters which are not meant to be boosted
        _boost_deltas.mul_(
            prev_boosted_clusters.unsqueeze(2).type(self._float_dtype))
        cluster_center_deltas.add_(_boost_deltas)
Example #12
    def forward(
        self, x: torch.Tensor,
        thw: Tuple[int, int,
                   int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        B, N, C = x.shape
        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                      self.head_dim).transpose(1,
                                                               3).unbind(dim=2)

        if self.pool_k is not None:
            k = self.pool_k(k, thw)[0]
        if self.pool_v is not None:
            v = self.pool_v(v, thw)[0]
        if self.pool_q is not None:
            q, thw = self.pool_q(q, thw)

        attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
        attn = attn.softmax(dim=-1)

        x = torch.matmul(attn, v)
        if self.residual_pool:
            x.add_(q)
        x = x.transpose(1, 2).reshape(B, -1, C)
        x = self.project(x)

        return x, thw
Example #13
    def _add_noise(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout_p > 0:
            x = F.dropout3d(x, self.dropout_p, training=self.enable_dropout) if self.is_3d else \
                F.dropout2d(x, self.dropout_p, training=self.enable_dropout)
        if self.noise_lvl > 0:
            x.add_(torch.randn_like(x.detach()) * self.noise_lvl)
        return x
Example #14
def _inplace_update(value: Tensor, stats: Tensor, momentum: Optional[float], counter: int, new_counter: int) -> Tensor:
    if momentum is None:
        value.mul_(counter / new_counter)
        value.add_(stats / new_counter)
    else:
        value.mul_(1 - momentum)
        value.add_(momentum * stats)
    return value
def _newton_raphson_step(theta: Tensor, weights: Tensor, num: Tensor,
                         lambda_: float) -> Tensor:
    den: Tensor = theta + lambda_ / weights
    func: Tensor = torch.sum(num / (den**2), dim=(0, 1)) - 1
    step: Tensor = 0.5 * (func / torch.sum(num / (den**3), dim=(0, 1)))

    theta.add_(step).clamp_(min=0.)  # clamp_ (in place); a plain clamp here would discard its result
    return func
Example #16
def symmetrize_matrix_(inp: torch.Tensor) -> torch.Tensor:
    """Inplace version of symmetrize_matrix.

    Args:
        inp (tensor): Matrix to symmetrize.
    """
    inp.add_(inp.transpose(-1, -2).clone())
    inp.mul_(0.5)
    return inp
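A quick usage sketch, assuming symmetrize_matrix_ from the snippet above is in scope; the clone() is what lets the in-place add_ read the untouched transpose.

import torch

m = torch.randn(4, 4)
symmetrize_matrix_(m)  # modifies m in place (and also returns it)
print(torch.allclose(m, m.transpose(-1, -2)))  # True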
Example #17
def _no_grad_trunc_normal(
    tensor: torch.Tensor,
    mean: float,
    std: float,
    a: float,
    b: float,
) -> torch.Tensor:
    """Initializes the input tensor with a truncated normal distribution.

    This method is based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor:
            The tensor to initialize.
        mean:
            Mean of the distribution.
        std:
            Standard deviation of the distribution.
        a:
            Minimum value of the distribution, values below will be clamped.
        b:
            Maximum value of the distribution, values above will be clamped.

    """
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
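A minimal usage sketch of the initializer above (its defining module needs math, warnings and torch imported); every value ends up clamped to [a, b].

import torch

w = torch.empty(10000)
_no_grad_trunc_normal(w, mean=0.0, std=1.0, a=-2.0, b=2.0)
print(w.min().item() >= -2.0, w.max().item() <= 2.0)  # True True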
Example #18
def symmetrize_potts_(weight: torch.Tensor) -> torch.Tensor:
    """Inplace version of symmetrize_potts.

    Args:
        weight (tensor): 4D tensor of shape (length, vocab, length, vocab) to symmetrize.
    """
    weight.add_(weight.permute(2, 3, 0, 1).clone())
    weight.mul_(0.5)
    return weight
Example #19
def inplace_momentum_update(tensor: torch.Tensor, update: torch.Tensor,
                            momentum: Optional[float], counter: int,
                            new_counter: int) -> torch.Tensor:
    if momentum is None:
        tensor.mul_(counter / new_counter)
        tensor.add_(update / new_counter)
    else:
        tensor.mul_(1 - momentum)
        tensor.add_(momentum * update)
    return tensor
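A small usage sketch of the momentum branch with made-up numbers; with momentum=0.1 the tensor becomes 0.9 * old + 0.1 * update (counter and new_counter are ignored in this branch).

import torch

running = torch.tensor([1.0, 1.0])
update = torch.tensor([3.0, 5.0])
inplace_momentum_update(running, update, momentum=0.1, counter=9, new_counter=10)
print(running)  # tensor([1.2000, 1.4000])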
Example #20
def fn(x: bool, y: torch.Tensor):
    if x:
        b = 1
        a = y + 3
        y.add_(8)
    else:
        b = 2
        a = y + 3
    c = b + a
    return a
Example #21
def _newton_raphson_step(theta: Tensor,
                         num: Tensor,
                         tmp: Tensor,
                         lr: float,
                         gamma: float,
                         lambda_: float) -> Tensor:
    den: Tensor = theta * tmp + lr * gamma * lambda_
    func: Tensor = (gamma ** 2) * torch.sum(num / (den ** 2), dim=(0, 1)) - 1
    funcp: Tensor = (gamma ** 2) * torch.sum((num * tmp) / (den ** 3), dim=(0, 1))

    theta.add_(0.5 * (func / funcp)).clamp_(min=0.)  # F.relu(theta + step)
    return func
Example #22
    def inverse_input_transform(self,
                                inputs: torch.Tensor,
                                uint8=False) -> torch.Tensor:
        with torch.no_grad():
            mean, std = self.make_mean_and_std(inputs)
            # Do first one not in place to make sure it's not overwriting the original.
            inputs = inputs * std
            inputs.add_(mean)
            inputs.clamp_(0, 255)
            if uint8:
                inputs = inputs.to(torch.uint8)
            return inputs
Example #23
def accept_reject(current_z: torch.Tensor,
                  current_v: torch.Tensor,
                  z: torch.Tensor,
                  v: torch.Tensor,
                  epsilon: torch.Tensor,
                  accept_hist: torch.Tensor,
                  hist_len: int,
                  U: Callable,
                  K: Callable,
                  max_step_size: Optional[float] = 0.5,
                  min_step_size: Optional[float] = 1e-4,
                  acceptance_threshold: Optional[float] = 0.65):
    """Accept/reject based on Hamiltonians for current and propose.

    Args:
        current_z: position *before* leap-frog steps
        current_v: speed *before* leap-frog steps
        z: position *after* leap-frog steps
        v: speed *after* leap-frog steps
        epsilon: step size of leap-frog.
        accept_hist: a tensor of size (batch_size,), each component of which is
            the number of times the trajectory has been accepted
        hist_len: an int for the chain length after the current step
        U: function to compute potential energy
        K: function to compute kinetic energy
        max_step_size: maximum step size for leap-frog
        min_step_size: minimum step size for leap-frog
        acceptance_threshold: threshold acceptance rate; increase the step size
            if the chain is accepted more than this, and decrease otherwise

    Returns:
        the new state z, the adapted step size epsilon, and the updated
        accept-reject history
    """
    current_Hamil = K(current_v) + U(current_z)
    propose_Hamil = K(v) + U(z)

    prob = torch.clamp_max(torch.exp(current_Hamil - propose_Hamil), 1.)
    accept = torch.gt(prob, torch.rand_like(prob))
    z = accept.view(-1, 1) * z + ~accept.view(-1, 1) * current_z

    accept_hist.add_(accept)
    criteria = torch.gt(accept_hist / hist_len, acceptance_threshold)
    adapt = criteria * 1.02 + ~criteria * 0.98
    epsilon.mul_(adapt).clamp_(min_step_size, max_step_size)

    return z, epsilon, accept_hist
Example #24
def adam_step(p: torch.Tensor, out_p: torch.Tensor, exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor, grad: torch.Tensor,
              lr: float, beta1: float, beta2: float, eps: float, scale: float, step: int, eps_mode: int, bias_correction: int, weight_decay: float):
    assert bias_correction == 1
    assert eps_mode == 1

    grad = grad.float()
    grad.div_(scale)

    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = exp_avg_sq.sqrt().add_(eps)

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    step_size = lr * math.sqrt(bias_correction2) / bias_correction1

    p.add_(exp_avg/denom + weight_decay*p.float(), alpha=-step_size)
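A hedged single-step usage sketch (all hyper-parameters are illustrative; out_p is accepted but unused by the snippet above, and math must be imported where adam_step is defined).

import torch

p = torch.zeros(3)
grad = torch.ones(3)
exp_avg = torch.zeros(3)
exp_avg_sq = torch.zeros(3)
adam_step(p, p, exp_avg, exp_avg_sq, grad,
          lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, scale=1.0,
          step=1, eps_mode=1, bias_correction=1, weight_decay=0.0)
print(p)  # approximately tensor([-0.0100, -0.0100, -0.0100]) after one bias-corrected step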
Example #25
def add_poisson(
    tensor: Tensor,
    lam: Union[Number, Tuple[Number, Number]],
    inplace: bool = False,
    clip: bool = True,
) -> Tuple[Tensor, Union[Number, Tensor]]:
    """Adds Poisson noise to a batch of input images.

    Args:
        tensor (Tensor): Tensor to add noise to; this should be in a B*** format, e.g. BCHW.
        lam (Union[Number, Tuple[Number, Number]]): Distribution rate parameter (lambda) for
            the noise being added. If a tuple is provided, a lambda is drawn from the uniform
            distribution between the two values for each batched input (B***).
        inplace (bool, optional): Whether to add the noise in-place. Defaults to False.
        clip (bool, optional): Whether to clip between image bounds (0.0-1.0 or 0-255).
            Defaults to True.

    Returns:
        Tuple[Tensor, Union[Number, Tensor]]: Tuple containing:
            * Copy of or reference to input tensor with noise added.
            * Lambda used for noise generation. This will be a tensor of per-batch lambdas
            if a range of lambda was provided.
    """
    if not inplace:
        tensor = tensor.clone()

    if isinstance(lam, (list, tuple)):
        if len(lam) == 1:
            lam = lam[0]
        else:
            assert len(lam) == 2
            (min_lam, max_lam) = lam
            uniform_generator = Uniform(min_lam, max_lam)
            shape = [tensor.shape[0]] + [1] * (len(tensor.shape) - 1)
            lam = uniform_generator.sample(shape)
    tensor.mul_(lam)
    poisson_generator = Poisson(torch.tensor(1, dtype=float))
    noise = poisson_generator.sample(tensor.shape)
    tensor.add_(noise)
    tensor.div_(lam)
    if clip:
        tensor = ssdn.utils.clip_img(tensor, inplace=True)

    return tensor, lam
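A brief usage sketch with clip=False, so it does not depend on ssdn.utils.clip_img; the input is float64 because the Poisson noise above is sampled in double precision and tensor.add_ cannot downcast it in place. The lambda value is made up.

import torch

imgs = torch.rand(2, 3, 8, 8, dtype=torch.float64)
noisy, lam = add_poisson(imgs, lam=30.0, inplace=False, clip=False)
print(noisy.shape, lam)          # torch.Size([2, 3, 8, 8]) 30.0
print(torch.equal(imgs, noisy))  # False -- the clone keeps the input untouched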
Example #26
        def trunc_normal_(
            tensor: Tensor,
            mean: float = 0.0,
            std: float = 1.0,
            a: float = -2.0,
            b: float = 2.0,
        ) -> Tensor:
            # code copied from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
            # commit: e9b369c

            # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
            def norm_cdf(x):
                # Computes standard normal cumulative distribution function
                return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

            if (mean < a - 2 * std) or (mean > b + 2 * std):
                warnings.warn(
                    "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                    "The distribution of values may be incorrect.",
                    stacklevel=2,
                )

            with torch.no_grad():
                # Values are generated by using a truncated uniform distribution and
                # then using the inverse CDF for the normal distribution.
                # Get upper and lower cdf values
                l = norm_cdf((a - mean) / std)
                u = norm_cdf((b - mean) / std)

                # Uniformly fill tensor with values from [l, u], then translate to
                # [2l-1, 2u-1].
                tensor.uniform_(2 * l - 1, 2 * u - 1)

                # Use inverse cdf transform for normal distribution to get truncated
                # standard normal
                tensor.erfinv_()

                # Transform to proper mean, std
                tensor.mul_(std * math.sqrt(2.0))
                tensor.add_(mean)

                # Clamp to ensure it's in the proper range
                tensor.clamp_(min=a, max=b)
                return tensor
Example #27
    def _find(self, idx: torch.Tensor, sums: torch.Tensor) -> torch.Tensor:
        full_levels = int(
            torch.log2(torch.tensor(self.capacity, dtype=torch.float)))

        for _ in range(full_levels):
            idx.mul_(2).add_(1)
            left_value = self.tree[idx]
            go_right = sums.gt(left_value)
            idx.add_(go_right)
            sums.sub_(left_value * go_right)

        left_idx = idx.mul(2).add_(1)
        not_done = left_idx.lt(self.tree.size(0))
        not_done_idxs = left_idx[not_done]
        left_value = self.tree[not_done_idxs]
        go_right = sums[not_done].gt(left_value)
        idx[not_done] = go_right + not_done_idxs

        return idx
Example #28
    def ifft_step(
        self, fframe: Tensor, tframe: Tensor, buf_wnorm: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Inverse transform with overlap add
        tframe.add_(
            torch.irfft(
                fframe * self.hann_norm,
                signal_ndim=1,
                normalized=False,
                signal_sizes=(self.n_fft.item(),),
            ).mul_(self.hann)
        )
        buf_wnorm = buf_wnorm.roll(-self.hop.item(), dims=(0,))
        buf_wnorm[-self.hop.item() :] = 0.0
        buf_wnorm += self.hann_sq
        norm = torch.clamp_min(buf_wnorm[: self.hop.item()], 1e-10)
        # norm = buf_wnorm[: self.hop.item()]
        # norm = torch.where(norm > 1e-10, norm, torch.ones_like(norm))
        tframe[: self.hop.item()].div_(norm)
        return buf_wnorm
Example #29
def _double_backward_update(V: Tensor,
                            WH: Tensor,
                            param: Parameter,
                            beta: float,
                            gamma: float,
                            l1_reg: float,
                            l2_reg: float,
                            pos: Tensor = None):
    param.grad = None
    if beta == 2:
        output_neg = V
        output_pos = WH
    elif beta == 1:
        output_neg = V / WH.add(eps)
        output_pos = None
    elif beta == 0:
        WH_eps = WH.add(eps)
        output_pos = WH_eps.reciprocal_()
        output_neg = output_pos.square().mul_(V)
    else:
        WH_eps = WH.add(eps)
        output_neg = WH_eps.pow(beta - 2).mul_(V)
        output_pos = WH_eps.pow_(beta - 1)

    # first backward
    WH.backward(output_neg, retain_graph=pos is None)
    neg = param.grad.relu_().add_(eps)

    if pos is None:
        param.grad = None
        WH.backward(output_pos)
        pos = param.grad.relu_().add_(eps)

    if l1_reg > 0:
        pos.add_(l1_reg)
    if l2_reg > 0:
        pos = pos.add(param.data, alpha=l2_reg)
    multiplier = neg.div_(pos)
    if gamma != 1:
        multiplier.pow_(gamma)
    param.data.mul_(multiplier)
class Linear(Module):
    """ Implementation of a single fully connected linear layer module """
    def __init__(self, input_size, output_size):
        self.input = 0

        # Normal initialization of the weights (mean: 0, std: sqrt(Xavier variance))
        variance = 2 / (input_size + output_size)
        self.weight = Tensor(output_size,
                             input_size).normal_(0, math.sqrt(variance))
        self.weight_grad = Tensor(self.weight.size())

        # Normal initialization of the biases (mean: 0 mean, std: 1)
        self.bias = Tensor(output_size).normal_(0, math.sqrt(1))
        self.bias_grad = Tensor(self.bias.size())

        # For Momentum in SGD
        self.velocity_weight = torch.zeros_like(self.weight_grad)
        self.velocity_bias = torch.zeros_like(self.bias_grad)

        # For Adadelta
        self.square_avg_w = torch.zeros_like(self.weight_grad)
        self.square_avg_b = torch.zeros_like(self.bias_grad)

        self.delta_x_acc_w = torch.zeros_like(self.weight_grad)
        self.delta_x_acc_b = torch.zeros_like(self.bias_grad)

    def forward(self, input):
        self.input = input
        return self.input.mm(self.weight.t()) + self.bias

    def backward(self, output_grad):
        self.weight_grad.add_(output_grad.t().mm(self.input))
        self.bias_grad.add_(output_grad.sum(0))
        return output_grad.mm(self.weight)

    def param(self):
        return [(self.weight, self.weight_grad, self.velocity_weight,
                 self.square_avg_w, self.delta_x_acc_w),
                (self.bias, self.bias_grad, self.velocity_bias,
                 self.square_avg_b, self.delta_x_acc_b)]