Example #1
def batchnorm(input, weight=None, bias=None, running_mean=None, running_var=None, training=True, eps=1e-5, momentum=0.1):
    ''' momentum = 1 restricts stats to the current mini-batch '''
    # This hack only works when momentum is 1 and avoids needing to track running stats
    # by substituting dummy variables
    running_mean = torch.zeros(int(np.prod(np.array(input.data.size()[1])))).cuda()
    running_var = torch.ones(int(np.prod(np.array(input.data.size()[1])))).cuda()
    return F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
Example #2
    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)
Example #3
def benchmark_batch_norm(data_shape):
    C = data_shape[1]
    x = torch.rand(data_shape)
    mean = torch.rand(C)
    var = torch.rand(C)
    weight = torch.rand(C)
    bias = torch.rand(C)
    NITER = 10000
    input_size = numpy.prod(data_shape)
    total_size = 2 * input_size + 4 * C
    for i in range(-10, NITER):
        if i == 0:
            s = time.time()
        F.batch_norm(x, mean, var, weight, bias)
    elapsed_sec = (time.time() - s) / NITER
    print(
        "batch_norm: data shape: %s, bandwidth: %.2f GB/s"
        % (data_shape, (total_size * 4) / elapsed_sec / 1e9)
    )
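A minimal driver for the benchmark above, assuming the function is in scope together with the imports it relies on (time, numpy, torch, and torch.nn.functional as F); the shapes are illustrative only:

import time
import numpy
import torch
import torch.nn.functional as F

# Measure batch_norm bandwidth for a couple of representative shapes
for shape in [(16, 32, 28, 28), (16, 128, 14, 14)]:
    benchmark_batch_norm(shape)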
Example #4
    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        b, c, h, w = x.size()
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        # Apply instance norm
        x_reshaped = x.contiguous().view(1, b * c, h, w)

        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)

        return out.view(b, c, h, w)
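The view to (1, b * c, h, w) turns batch norm into per-sample, per-channel (instance) normalization. A quick self-contained check of that equivalence, not taken from the original source, using the default eps and no affine parameters:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8, 8)
b, c, h, w = x.shape
# Reference: instance norm with batch statistics, no affine transform
ref = F.instance_norm(x)
# Same result via batch_norm over a (1, b*c, h, w) view in training mode
out = F.batch_norm(x.view(1, b * c, h, w), None, None, training=True)
print(torch.allclose(ref, out.view(b, c, h, w), atol=1e-5))  # True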
Example #5
    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        # Apply instance norm
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])

        # self.weight = self.weight[:b*c]

        # out = F.batch_norm(
        #     x_reshaped, running_mean, running_var, self.weight[:b*c], self.bias[:b*c],
        #     True, self.momentum, self.eps)
        # self.weight = self.weight[b*c:]
        # self.bias = self.bias[b*c:]

        out = F.batch_norm(x_reshaped, running_mean, running_var, self.weight,
                           self.bias, True, self.momentum, self.eps)

        return out.view(b, c, *x.size()[2:])
Example #6
    def forward(self, x, y):
        """

    :param x:
    :param y: feature [b, self.input_size]
    :return:
    """
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)

        out = F.batch_norm(x,
                           self.stored_mean,
                           self.stored_var,
                           weight=None,
                           bias=None,
                           training=self.training,
                           momentum=self.momentum,
                           eps=self.eps)
        out = out * gain + bias
        return out
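For context, self.gain and self.bias here map the conditioning vector y to per-channel parameters, while stored_mean and stored_var are running-stat buffers. A minimal sketch of how such a module could be wired; the linear projections and buffer initialization below are assumptions for illustration, not the original class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassConditionalBN(nn.Module):
    def __init__(self, input_size, output_size, eps=1e-5, momentum=0.1):
        super().__init__()
        # Assumed: plain linear projections from the conditioning vector
        self.gain = nn.Linear(input_size, output_size)
        self.bias = nn.Linear(input_size, output_size)
        self.eps, self.momentum = eps, momentum
        self.register_buffer('stored_mean', torch.zeros(output_size))
        self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, self.momentum, self.eps)
        return out * gain + bias

# Example: 4 images with 16 channels, conditioned on a 128-dim embedding
m = ClassConditionalBN(128, 16)
print(m(torch.randn(4, 16, 8, 8), torch.randn(4, 128)).shape)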
Example #7
    def forward(self, input_):
        self._check_input_dim(input_)

        input_ = torch.cat(torch.chunk(input_, 2, dim=1), dim=0)
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(
                        self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        output = F.batch_norm(input_, self.running_mean, self.running_var,
                              self.weight, self.bias, self.training
                              or not self.track_running_stats,
                              exponential_average_factor, self.eps)
        output = torch.cat(torch.chunk(output, 2, dim=0), dim=1)

        return output
Example #8
    def forward(self, input):
        self._check_input_dim(input)

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / \
                        float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.p_weight, self.p_bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
Example #9
    def forward(self, x):
        assert (self.weight is not None and self.bias is not None
                ), "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])

        out = F.batch_norm(
            x_reshaped,
            running_mean,
            running_var,
            self.weight,
            self.bias,
            True,
            self.momentum,
            self.eps,
        )

        return out.view(b, c, *x.size()[2:])
Example #10
 def forward(self, input, fin_feature=None):
     fm = input
     layer = 0
     for l in range(self.num_blocks):
         prefix = 'l' + str(l)
         fm = F.conv3d(fm, getattr(self, prefix + '_conv_w'), padding=1)
         fm = F.batch_norm(fm,
                           getattr(self, prefix + '_running_mean'),
                           getattr(self, prefix + '_running_var'),
                           getattr(self, prefix + '_bn_w'),
                           getattr(self, prefix + '_bn_b'),
                           training=self.training)
         fm = F.max_pool3d(fm, 2, 2)
         fm = F.dropout(fm, p=self.drop_rate, training=self.training)
         if l != self.num_blocks - 1: fm = F.relu(fm)
     feature = fm.view(fm.size(0), -1)
     output = F.linear(feature, self.lr_fc_w, self.lr_fc_b)
     if fin_feature is None:
         return output
     else:
         return output, feature
Example #11
def batchnorm(input,
              weight=None,
              bias=None,
              running_mean=None,
              running_var=None,
              training=True,
              eps=1e-5,
              momentum=0.1):
    ''' momentum = 1 restricts stats to the current mini-batch '''
    # This hack only works when momentum is 1 and avoids needing to track running stats
    # by substituting dummy variables
    in_dim = input.data.size()[1]
    from Utils.config import USE_GPU
    if USE_GPU:
        running_mean = torch.zeros(in_dim).cuda()
        running_var = torch.ones(in_dim).cuda()
    else:
        running_mean = torch.zeros(in_dim)
        running_var = torch.ones(in_dim)
    return F.batch_norm(input, running_mean, running_var, weight, bias,
                        training, momentum, eps)
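A short check, not from the original source, of what the docstring means: in training mode the output depends only on the current mini-batch statistics, and with momentum = 1 the dummy buffers are simply overwritten by those statistics:

import torch
import torch.nn.functional as F

x = torch.randn(8, 4)
running_mean, running_var = torch.zeros(4), torch.ones(4)
out = F.batch_norm(x, running_mean, running_var, training=True, momentum=1.0)
# Normalization by the batch mean and (biased) batch variance
ref = (x - x.mean(0)) / torch.sqrt(x.var(0, unbiased=False) + 1e-5)
print(torch.allclose(out, ref, atol=1e-5))  # True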
Example #12
    def forward(self, input, weight, bias, **kwargs):
        self._check_input_dim(input)

        exponential_average_factor = self.momentum if self.training else 0.0

        output = F.batch_norm(input, self.running_mean, self.running_var,
                              self.weight, self.bias, self.training,
                              exponential_average_factor, self.eps)

        # expand dimension to 2 if dimension is 1
        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        size = output.size()

        # expand dimension to 4 to calculate affine transformation
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)

        return weight * output + bias
Example #13
def aten_batch_norm(inputs, attributes, scope):
    inp, weight, bias, running_mean, running_var = inputs[:5]
    training, momentum, eps = inputs[5:8]
    # assert training is False
    net = current_network()
    if net is not None and has_trt_tensor(inputs):
        running_mean = running_mean.detach().cpu().numpy()
        running_var = running_var.detach().cpu().numpy()
        weight = weight.detach().cpu().numpy()
        bias = bias.detach().cpu().numpy()
        shift = (-running_mean / np.sqrt(running_var + eps)) * weight + bias
        scale = weight / np.sqrt(running_var + eps)
        layer = net.add_scale(inp, trt.ScaleMode.CHANNEL, shift, scale,
                              np.ones_like(shift))
        output = layer.get_output(0)
        output.name = scope
        layer.name = scope
        return [output]
    res = F.batch_norm(inp, running_mean, running_var, weight, bias,
                       bool(training), momentum, eps)
    return [res]
Example #14
 def forward(self, x, y):
   # Calculate class-conditional gains and biases
   gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
   bias = self.bias(y).view(y.size(0), -1, 1, 1)
   # If using my batchnorm
   if self.mybn or self.cross_replica:
     return self.bn(x, gain=gain, bias=bias)
   # else:
   else:
     if self.norm_style == 'bn':
       out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                         self.training, 0.1, self.eps)
     elif self.norm_style == 'in':
       out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                         self.training, 0.1, self.eps)
     elif self.norm_style == 'gn':
       out = groupnorm(x, self.normstyle)
     elif self.norm_style == 'nonorm':
       out = x
   #   return out * gain + bias
     return out
Example #15
    def forward(self, input):
        if not self.training:
            return batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input_shape[0], self.num_features, -1)

        # sum(x) and sum(x^2)
        N = input.size(0) * input.size(2)
        xsum, xsqsum = sum_square(input)

        # all-reduce for global sum(x) and sum(x^2)
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(xsum, xsqsum, N))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(xsum, xsqsum, N))
        # forward
        return batchnormtrain(input, mean, 1.0/inv_std, self.weight, self.bias).view(input_shape)
Example #16
 def forward(self, x):
     if x.requires_grad:
         # When gradients are needed, F.batch_norm will use extra memory
         # because its backward op computes gradients for weight/bias as well.
         scale = self.weight * (self.running_var + self.eps).rsqrt()
         bias = self.bias - self.running_mean * scale
         scale = scale.reshape(1, -1, 1, 1)
         bias = bias.reshape(1, -1, 1, 1)
         return x * scale + bias
     else:
         # When gradients are not needed, F.batch_norm is a single fused op
         # and provides more optimization opportunities.
         return F.batch_norm(
             x,
             self.running_mean,
             self.running_var,
             self.weight,
             self.bias,
             training=False,
             eps=self.eps,
         )
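A quick standalone check, not part of the original module, that the expanded scale/bias form above matches the fused eval-mode call for the same buffers:

import torch
import torch.nn.functional as F

c, eps = 8, 1e-5
x = torch.randn(2, c, 5, 5)
weight, bias = torch.rand(c) + 0.5, torch.randn(c)
running_mean, running_var = torch.randn(c), torch.rand(c) + 0.5

scale = weight * (running_var + eps).rsqrt()
shift = bias - running_mean * scale
expanded = x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)
fused = F.batch_norm(x, running_mean, running_var, weight, bias,
                     training=False, eps=eps)
print(torch.allclose(expanded, fused, atol=1e-5))  # True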
Example #17
def functional_conv_block(x: torch.Tensor, weights: torch.Tensor,
                          biases: torch.Tensor, bn_weights,
                          bn_biases) -> torch.Tensor:
    """Performs 3x3 convolution, ReLu activation, 2x2 max pooling in a functional fashion.
    # Arguments:
        x: Input Tensor for the conv block
        weights: Weights for the convolutional block
        biases: Biases for the convolutional block
        bn_weights:
        bn_biases:
    """
    x = F.conv2d(x, weights, biases, padding=1)
    x = F.batch_norm(x,
                     running_mean=None,
                     running_var=None,
                     weight=bn_weights,
                     bias=bn_biases,
                     training=True)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x
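A possible call of the block above, assuming it is in scope; the 3-to-64-channel shapes are illustrative, not taken from the original repository:

import torch

x = torch.randn(4, 3, 32, 32)
conv_w = torch.randn(64, 3, 3, 3) * 0.01   # 3x3 kernels, 3 -> 64 channels
conv_b = torch.zeros(64)
bn_w = torch.ones(64)                      # batch-norm gamma
bn_b = torch.zeros(64)                     # batch-norm beta
out = functional_conv_block(x, conv_w, conv_b, bn_w, bn_b)
print(out.shape)  # torch.Size([4, 64, 16, 16])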
Example #18
    def forward(self, x):
        if self.training and self.sync:
            if x.get_device() == self.devices[0]:
                extra = {
                    'is_master': True,
                    'master_queue': self.master_queue,
                    'worker_queues': self.worker_queues,
                    'worker_ids': self.worker_ids
                }
            else:
                extra = {
                    'is_master':
                    False,
                    'master_queue':
                    self.master_queue,
                    'worker_queue':
                    self.worker_queues[self.worker_ids.index(x.get_device())]
                }

            return SyncBNFucntion.apply(x, self.weight, self.bias,
                                        self.running_mean, self.running_var,
                                        extra, self.training, self.momentum,
                                        self.eps)
        else:
            exponential_average_factor = 0.0

            if self.training and self.track_running_stats:
                # TODO: if statement only here to tell the jit to skip emitting this when it is None
                if self.num_batches_tracked is not None:
                    self.num_batches_tracked += 1
                    if self.momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(
                            self.num_batches_tracked)
                    else:  # use exponential moving average
                        exponential_average_factor = self.momentum

            return F.batch_norm(x, self.running_mean, self.running_var,
                                self.weight, self.bias, self.training
                                or not self.track_running_stats,
                                exponential_average_factor, self.eps)
Example #19
	def forward(self, input):
		lname = self._write_caffe(input[1])
		input = input[0]
		if self.momentum is None:
			exponential_average_factor = 0.0
		else:
			exponential_average_factor = self.momentum

		if self.training and self.track_running_stats:
			# TODO: if statement only here to tell the jit to skip emitting this when it is None
			if self.num_batches_tracked is not None:
				self.num_batches_tracked += 1
				if self.momentum is None:  # use cumulative moving average
					exponential_average_factor = 1.0 / float(self.num_batches_tracked)
				else:  # use exponential moving average
					exponential_average_factor = self.momentum

		result =  F.batch_norm(
			input, self.running_mean, self.running_var, self.weight, self.bias,
			self.training or not self.track_running_stats,
			exponential_average_factor, self.eps)
		return result, lname
Example #20
    def forward(self, input, weight, bias, **kwargs):
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        output = F.batch_norm(input, self.running_mean, self.running_var,
                              self.weight, self.bias,
                              self.training or not self.track_running_stats,
                              exponential_average_factor, self.eps)
        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        size = output.size()
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
        return weight * output + bias
Example #21
    def forward_with_args(self, num_features: int_or_int_dict, eps: float,
                          momentum: float,
                          inputs: torch.Tensor) -> torch.Tensor:

        if any(isinstance(arg, dict) for arg in [eps, momentum]):
            raise ValueError('eps, momentum do not support weighted sampling')

        if isinstance(num_features, dict):
            num_features = self.num_features

        weight, bias = self.weight, self.bias
        running_mean, running_var = self.running_mean, self.running_var

        if num_features < self.num_features:
            weight = weight[:num_features]
            bias = bias[:num_features]
            if running_mean is not None:
                running_mean = running_mean[:num_features]
            if running_var is not None:
                running_var = running_var[:num_features]

        if self.training:
            bn_training = True
        else:
            bn_training = (running_mean is None) and (running_var is None)

        return F.batch_norm(
            inputs,
            # If buffers are not to be tracked, ensure that they won't be updated
            running_mean
            if not self.training or self.track_running_stats else None,
            running_var
            if not self.training or self.track_running_stats else None,
            weight,
            bias,
            bn_training,
            momentum,  # originally exponential_average_factor in pytorch code
            eps,
        )
Example #22
    def _instance_norm(input,
                       group,
                       running_mean=None,
                       running_var=None,
                       weight=None,
                       bias=None,
                       use_input_stats=None,
                       momentum=None,
                       eps=None):
        # Repeat stored stats and affine transform params if necessary
        if running_mean is not None:
            running_mean_orig = running_mean
            running_mean = running_mean_orig.repeat(b)
        if running_var is not None:
            running_var_orig = running_var
            running_var = running_var_orig.repeat(b)

        # Apply instance norm
        input_reshaped = input.contiguous().view(1, int(b * c / group), group,
                                                 *input.size()[2:])

        out = F.batch_norm(input_reshaped,
                           running_mean,
                           running_var,
                           weight=weight,
                           bias=bias,
                           training=use_input_stats,
                           momentum=momentum,
                           eps=eps)

        # Reshape back
        if running_mean is not None:
            running_mean_orig.copy_(
                running_mean.view(b, int(c / group)).mean(0, keepdim=False))
        if running_var is not None:
            running_var_orig.copy_(
                running_var.view(b, int(c / group)).mean(0, keepdim=False))

        return out.view(b, c, *input.size()[2:])
Example #23
    def forward_mask(self, inputs, mask=None):
        if mask is None or inputs.shape[1] == self.flex_bn.num_features:
            return self.flex_bn.forward(inputs)
        
        """ _BatchNorm official code"""
        if self.flex_bn.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.flex_bn.momentum

        if self.flex_bn.training and self.flex_bn.track_running_stats:
            if self.flex_bn.num_batches_tracked is not None:
                self.flex_bn.num_batches_tracked += 1
                if self.flex_bn.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.flex_bn.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.flex_bn.momentum
        running_mean, running_var, weight, bias = self._select_params(mask)
        return F.batch_norm(
            inputs, running_mean, running_var, weight, bias,
            self.flex_bn.training or not self.flex_bn.track_running_stats,
            exponential_average_factor, self.flex_bn.eps)
Example #24
    def forward(self, input, z=None):
        # if input.dim() == 2, we switch to channel_last for efficient memory accessing
        channel_last = self.channel_last if input.dim() != 2 else True

        if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
            # fall back to pytorch implementation for inference
            return F.batch_norm(input, self.running_mean, self.running_var,
                                self.weight, self.bias, False, 0.0, self.eps)
        else:
            exponential_average_factor = 0.0
            if self.training and self.track_running_stats:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    exponential_average_factor = 1.0 / float(
                        self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum
            return SyncBatchnormFunction.apply(
                input, z, self.weight, self.bias, self.running_mean,
                self.running_var, self.eps, self.training
                or not self.track_running_stats, exponential_average_factor,
                self.process_group, self.channel_last, self.fuse_relu)
Example #25
    def forward(self, x):
        if self.training and self.sync:
            return DistributedSyncBNFucntion.apply(
                x, self.weight, self.bias, self.running_mean, self.running_var,
                self.training, self.momentum, self.eps, self.sync)
        else:
            exponential_average_factor = 0.0

            if self.training and self.track_running_stats:
                # TODO: if statement only here to tell the jit to skip emitting this when it is None
                if self.num_batches_tracked is not None:
                    self.num_batches_tracked += 1
                    if self.momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(
                            self.num_batches_tracked)
                    else:  # use exponential moving average
                        exponential_average_factor = self.momentum

            return F.batch_norm(x, self.running_mean, self.running_var,
                                self.weight, self.bias, self.training
                                or not self.track_running_stats,
                                exponential_average_factor, self.eps)
Example #26
    def forward(self, x, w):
        self._check_input_dim(x)
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        output = F.batch_norm(x, self.running_mean, self.running_var,
                              self.weight, self.bias, self.training
                              or not self.track_running_stats,
                              exponential_average_factor, self.eps)

        size = output.size()
        weight, bias = self.weight_proj(w) + 1, self.bias_proj(w)
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)

        return weight * output + bias
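Here the per-sample affine parameters come from projections of an external embedding w. A minimal sketch of how weight_proj and bias_proj might be defined; the linear projections and the plain nn.BatchNorm2d base are assumptions for illustration, not the original class:

import torch
import torch.nn as nn

class ProjectedConditionalBN(nn.BatchNorm2d):
    def __init__(self, num_features, w_dim):
        # affine=False: the affine transform comes from the projections instead
        super().__init__(num_features, affine=False)
        self.weight_proj = nn.Linear(w_dim, num_features)
        self.bias_proj = nn.Linear(w_dim, num_features)

    def forward(self, x, w):
        out = super().forward(x)  # plain batch norm without affine parameters
        size = out.size()
        weight = (self.weight_proj(w) + 1).unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = self.bias_proj(w).unsqueeze(-1).unsqueeze(-1).expand(size)
        return weight * out + bias

m = ProjectedConditionalBN(16, w_dim=32)
print(m(torch.randn(4, 16, 8, 8), torch.randn(4, 32)).shape)  # torch.Size([4, 16, 8, 8])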
Example #27
    def forward(self, input):
        self._check_input_dim(input)

        if not self.training:
            return F.batch_norm(input, self.running_mean, self.running_var,
                                self.weight, self.bias, False)

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(
                        self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is
                                                           None)

        batch_mean = input.mean(dim=0)
        batch_var = ((input - self.running_mean) *
                     (input - self.running_mean)).mean(dim=0)
        output = (input - self.running_mean) / (self.running_var +
                                                self.eps).sqrt()
        if self.training:
            with torch.no_grad():
                self.running_mean[...] = batch_mean
                self.running_var[...] = batch_var
        result = output * self.weight + self.bias
        return result
Example #28
    def forward(self, input):  # TODO: weight is not quantized
        self._check_input_dim(input)
        if config.quantize_activation:
            qinput = self.quantize_input(input)
        else:
            qinput = input

        # if config.quantize_weights:
        #     qweight = quantize(self.weight, config.bias_preconditioner())
        #     qbias = quantize(self.bias, config.bias_preconditioner())
        # else:
        #     qweight = self.weight
        #     qbias = self.bias

        qweight = self.weight
        qbias = self.bias

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(
                        self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(input, self.running_mean, self.running_var,
                            qweight, qbias, self.training
                            or not self.track_running_stats,
                            exponential_average_factor, self.eps)
Example #29
    def forward(self, bn_weight, bn_bias, *inputs):
        if self.training:
            # Save the current BN statistics for later
            self.prev_running_mean.copy_(self.running_mean)
            self.prev_running_var.copy_(self.running_var)

        # Create tensors that use shared allocations
        # One for the concatenation output (bn_input)
        # One for the ReLU output (relu_output)
        all_num_channels = [input.size(1) for input in inputs]
        size = list(inputs[0].size())
        for num_channels in all_num_channels[1:]:
            size[1] += num_channels
        storage = self.shared_allocation_1.storage_for(inputs[0])
        bn_input_var = Variable(type(inputs[0])(storage).resize_(size),
                                volatile=True)
        relu_output = type(inputs[0])(storage).resize_(size)

        # Create variable, using existing storage
        torch.cat(inputs, dim=1, out=bn_input_var.data)

        # Do batch norm
        bn_weight_var = Variable(bn_weight)
        bn_bias_var = Variable(bn_bias)
        bn_output_var = F.batch_norm(bn_input_var,
                                     self.running_mean,
                                     self.running_var,
                                     bn_weight_var,
                                     bn_bias_var,
                                     training=self.training,
                                     momentum=self.momentum,
                                     eps=self.eps)

        # Do ReLU - and have the output be in the intermediate storage
        torch.clamp(bn_output_var.data, 0, 1e100, out=relu_output)

        self.save_for_backward(bn_weight, bn_bias, *inputs)
        return relu_output
Example #30
    def prepare_backward(self):
        bn_weight, bn_bias = self.saved_tensors[:2]
        inputs = self.saved_tensors[2:]

        # Temporarily reset batch norm statistics
        self.curr_running_mean.copy_(self.running_mean)
        self.curr_running_var.copy_(self.running_var)
        self.running_mean.copy_(self.prev_running_mean)
        self.running_var.copy_(self.prev_running_var)

        # Re-do the forward pass to re-populate the shared storage
        all_num_channels = [input.size(1) for input in inputs]
        size = list(inputs[0].size())
        for num_channels in all_num_channels[1:]:
            size[1] += num_channels
        storage1 = self.shared_allocation_1.storage_for(inputs[0])
        self.bn_input_var = Variable(type(inputs[0])(storage1).resize_(size),
                                     requires_grad=True)
        storage2 = self.shared_allocation_2.storage_for(inputs[0])
        self.relu_output = type(inputs[0])(storage2).resize_(size)

        # Create variable, using existing storage
        torch.cat(inputs, dim=1, out=self.bn_input_var.data)

        # Do batch norm
        self.bn_weight_var = Variable(bn_weight, requires_grad=True)
        self.bn_bias_var = Variable(bn_bias, requires_grad=True)
        self.bn_output_var = F.batch_norm(self.bn_input_var,
                                          self.running_mean,
                                          self.running_var,
                                          self.bn_weight_var,
                                          self.bn_bias_var,
                                          training=self.training,
                                          momentum=self.momentum,
                                          eps=self.eps)

        # Do ReLU
        torch.clamp(self.bn_output_var.data, 0, 1e100, out=self.relu_output)
Example #31
    def forward(self, input):
        self._check_input_dim(input)

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    exponential_average_factor = 1.0 / float(
                        self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum

        if self.affine:
            if self.weight_eps is None:
                weight = self.weight_mu + torch.exp(
                    self.weight_log_sigma) * torch.randn_like(
                        self.weight_log_sigma)
                bias = self.bias_mu + torch.exp(
                    self.bias_log_sigma) * torch.randn_like(
                        self.bias_log_sigma)
            else:
                weight = self.weight_mu + torch.exp(
                    self.weight_log_sigma) * self.weight_eps
                bias = self.bias_mu + torch.exp(
                    self.bias_log_sigma) * self.bias_eps
        else:
            weight = None
            bias = None

        return F.batch_norm(input, self.running_mean, self.running_var, weight,
                            bias, self.training
                            or not self.track_running_stats,
                            exponential_average_factor, self.eps)
Example #32
    def forward(self, input, weight=None, bias=None):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(input, self.running_mean, self.running_var,
                                self.weight, self.bias, self.training,
                                self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input**2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(
                _ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(
                _ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            if weight is None or bias is None:
                weight = self.weight
                bias = self.bias

            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(
                inv_std * weight) + _unsqueeze_ft(bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)
Example #33
    def forward(self, input):
        self._check_input_dim(input)

        # hack to work around model.eval() issue
        if not self.training:
            self.eval_momentum = 0  # set the momentum to zero when the model is validating

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum if self.training else self.eval_momentum

        if self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum if self.training else self.eval_momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            training=True, momentum=exponential_average_factor, eps=self.eps)  # set training to True so it calculates the norm of the batch
Example #34
    def forward(self, x):
        N, C, H, W = x.shape
        x = functional.batch_norm(x.view(-1, C * self.num_splits, H, W),
                                  self.running_mean, self.running_var,
                                  self.weight.repeat(self.num_splits),
                                  self.bias.repeat(self.num_splits), True,
                                  self.momentum, self.eps).view(N, C, H, W)
        #         x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
        #                                   self.training, self.momentum, self.eps)

        if self.activation == "relu":
            return functional.relu(x, inplace=True)
        elif self.activation == "leaky_relu":
            return functional.leaky_relu(x,
                                         negative_slope=self.activation_param,
                                         inplace=True)
        elif self.activation == "elu":
            return functional.elu(x, alpha=self.activation_param, inplace=True)
        elif self.activation == "identity":
            return x
        else:
            raise RuntimeError("Unknown activation function {}".format(
                self.activation))
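The view to (-1, C * self.num_splits, H, W) gives each split of the batch its own normalization statistics (ghost/split batch norm). A stripped-down functional sketch of the same idea, with running statistics omitted for brevity:

import torch
import torch.nn.functional as F

def ghost_batch_norm(x, num_splits, weight=None, bias=None, eps=1e-5):
    # Requires the batch size to be divisible by num_splits
    N, C, H, W = x.shape
    out = F.batch_norm(
        x.view(-1, C * num_splits, H, W), None, None,
        weight.repeat(num_splits) if weight is not None else None,
        bias.repeat(num_splits) if bias is not None else None,
        True, 0.1, eps)
    return out.view(N, C, H, W)

x = torch.randn(8, 4, 6, 6)
print(ghost_batch_norm(x, num_splits=2).shape)  # torch.Size([8, 4, 6, 6])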
Example #35
def batch_norm(x, params, base, mode):
    return F.batch_norm(x, weight=params[base + '.weight'],
                        bias=params[base + '.bias'],
                        running_mean=params[base + '.running_mean'],
                        running_var=params[base + '.running_var'],
                        training=mode)
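One way the dictionary-based helper above might be fed, assuming it is in scope: pull the parameters and buffers out of a standard module and prefix them (the 'layer1' prefix is only an example):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)
params = {'layer1.' + k: v for k, v in
          list(bn.named_parameters()) + list(bn.named_buffers())}
x = torch.randn(4, 16, 8, 8)
out = batch_norm(x, params, 'layer1', mode=bn.training)
print(out.shape)  # torch.Size([4, 16, 8, 8])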
Example #36
 def forward(self, input):
     self._check_input_dim(input)
     return F.batch_norm(
         input, self.running_mean, self.running_var, self.weight, self.bias,
         self.training, self.momentum, self.eps)