def batchnorm(input, weight=None, bias=None, running_mean=None, running_var=None,
              training=True, eps=1e-5, momentum=0.1):
    ''' momentum = 1 restricts stats to the current mini-batch '''
    # This hack only works when momentum is 1 and avoids needing to track
    # running stats by substituting dummy variables.
    num_features = input.data.size()[1]
    running_mean = torch.zeros(num_features).cuda()
    running_var = torch.ones(num_features).cuda()
    return F.batch_norm(input, running_mean, running_var, weight, bias,
                        training, momentum, eps)

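# A minimal sanity check for the momentum=1 hack above (a sketch, assuming CPU
# tensors and no affine parameters): with dummy running stats and momentum=1,
# F.batch_norm reduces to plain per-batch standardization, and the dummy
# buffers are simply overwritten by the batch statistics.
import torch
import torch.nn.functional as F

x = torch.randn(8, 4)
running_mean = torch.zeros(4)
running_var = torch.ones(4)
out = F.batch_norm(x, running_mean, running_var, None, None, True, 1.0, 1e-5)
manual = (x - x.mean(0)) / (x.var(0, unbiased=False) + 1e-5).sqrt()
assert torch.allclose(out, manual, atol=1e-5)
assert torch.allclose(running_mean, x.mean(0), atol=1e-5)
assert torch.allclose(running_var, x.var(0, unbiased=True), atol=1e-5)  # update uses unbiased variance
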
def forward(self, input):
    # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
    if not (self._is_parallel and self.training):
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    # Resize the input to (B, C, -1).
    input_shape = input.size()
    input = input.view(input.size(0), self.num_features, -1)

    # Compute the sum and square-sum.
    sum_size = input.size(0) * input.size(2)
    input_sum = _sum_ft(input)
    input_ssum = _sum_ft(input ** 2)

    # Reduce-and-broadcast the statistics.
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
    else:
        mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

    # Compute the output.
    if self.affine:
        # MJY:: Fuse the multiplication for speed.
        output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
    else:
        output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

    # Reshape it.
    return output.view(input_shape)

import time

import numpy
import torch
import torch.nn.functional as F


def benchmark_batch_norm(data_shape):
    C = data_shape[1]
    x = torch.rand(data_shape)
    mean = torch.rand(C)
    var = torch.rand(C)
    weight = torch.rand(C)
    bias = torch.rand(C)
    NITER = 10000
    input_size = numpy.prod(data_shape)
    # One read and one write of the activation, plus the four per-channel
    # parameter vectors.
    total_size = 2 * input_size + 4 * C
    for i in range(-10, NITER):  # 10 warm-up iterations before the timer starts
        if i == 0:
            s = time.time()
        F.batch_norm(x, mean, var, weight, bias)
    elapsed_sec = (time.time() - s) / NITER
    print(
        "batch_norm: data shape: %s, bandwidth: %.2f GB/s"
        % (data_shape, (total_size * 4) / elapsed_sec / 1e9)
    )

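# Hypothetical invocation of the benchmark above; the shape is illustrative.
# Bandwidth is estimated as bytes moved (float32) divided by time per call.
benchmark_batch_norm((32, 64, 56, 56))
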
def forward(self, x):
    assert self.weight is not None and self.bias is not None, \
        "Please assign weight and bias before calling AdaIN!"
    b, c, h, w = x.size()
    running_mean = self.running_mean.repeat(b)
    running_var = self.running_var.repeat(b)

    # Apply instance norm
    x_reshaped = x.contiguous().view(1, b * c, h, w)

    out = F.batch_norm(
        x_reshaped, running_mean, running_var, self.weight, self.bias,
        True, self.momentum, self.eps)

    return out.view(b, c, h, w)

def forward(self, x):
    assert self.weight is not None and self.bias is not None, \
        "Please assign weight and bias before calling AdaIN!"
    b, c = x.size(0), x.size(1)
    running_mean = self.running_mean.repeat(b)
    running_var = self.running_var.repeat(b)

    # Apply instance norm
    x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
    # self.weight = self.weight[:b*c]
    # out = F.batch_norm(
    #     x_reshaped, running_mean, running_var, self.weight[:b*c], self.bias[:b*c],
    #     True, self.momentum, self.eps)
    # self.weight = self.weight[b*c:]
    # self.bias = self.bias[b*c:]
    out = F.batch_norm(x_reshaped, running_mean, running_var, self.weight,
                       self.bias, True, self.momentum, self.eps)
    return out.view(b, c, *x.size()[2:])

def forward(self, x, y):
    """
    :param x: input feature map [b, c, h, w]
    :param y: conditioning feature [b, self.input_size]
    :return: normalized feature map with a class-conditional affine transform
    """
    # Calculate class-conditional gains and biases
    gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
    bias = self.bias(y).view(y.size(0), -1, 1, 1)
    out = F.batch_norm(x, self.stored_mean, self.stored_var, weight=None, bias=None,
                       training=self.training, momentum=self.momentum, eps=self.eps)
    out = out * gain + bias
    return out

def forward(self, input_):
    self._check_input_dim(input_)
    input_ = torch.cat(torch.chunk(input_, 2, dim=1), dim=0)

    exponential_average_factor = 0.0
    if self.training and self.track_running_stats:
        if self.num_batches_tracked is not None:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

    output = F.batch_norm(
        input_, self.running_mean, self.running_var, self.weight, self.bias,
        self.training or not self.track_running_stats,
        exponential_average_factor, self.eps)

    output = torch.cat(torch.chunk(output, 2, dim=0), dim=1)
    return output

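# Shape sketch of the chunk/cat trick above: a (B, 2C, ...) input becomes
# (2B, C, ...), so channels k and k + C are normalized with one shared set of
# statistics and affine parameters. Sizes below are illustrative.
import torch

x = torch.randn(4, 6, 8, 8)
y = torch.cat(torch.chunk(x, 2, dim=1), dim=0)
assert y.shape == (8, 3, 8, 8)
assert torch.equal(y[4], x[0, 3:])  # second channel-half of sample 0
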
def forward(self, input):
    self._check_input_dim(input)

    if self.momentum is None:
        exponential_average_factor = 0.0
    else:
        exponential_average_factor = self.momentum

    if self.training and self.track_running_stats:
        if self.num_batches_tracked is not None:
            self.num_batches_tracked = self.num_batches_tracked + 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

    return F.batch_norm(
        input, self.running_mean, self.running_var, self.p_weight, self.p_bias,
        self.training or not self.track_running_stats,
        exponential_average_factor, self.eps)

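# The momentum=None branch above implements a cumulative moving average: with
# factor 1/n at step n, the buffer ends up holding the plain mean of all batch
# statistics seen so far. A small numeric sketch with made-up statistics:
running = 0.0
for n, stat in enumerate([2.0, 4.0, 6.0], start=1):
    factor = 1.0 / n
    running = (1 - factor) * running + factor * stat
assert abs(running - 4.0) < 1e-12  # mean of [2.0, 4.0, 6.0]
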
def forward(self, x):
    assert (self.weight is not None and self.bias is not None), \
        "Please assign weight and bias before calling AdaIN!"
    b, c = x.size(0), x.size(1)
    running_mean = self.running_mean.repeat(b)
    running_var = self.running_var.repeat(b)
    x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
    out = F.batch_norm(
        x_reshaped, running_mean, running_var, self.weight, self.bias,
        True, self.momentum, self.eps,
    )
    return out.view(b, c, *x.size()[2:])

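# The reshape-to-(1, B*C) pattern shared by the AdaIN-style forwards above is
# the usual way to express instance norm via batch norm. A minimal check,
# assuming no affine parameters and fresh dummy running stats:
import torch
import torch.nn.functional as F

b, c, h, w = 2, 3, 5, 5
x = torch.randn(b, c, h, w)
rm, rv = torch.zeros(b * c), torch.ones(b * c)
out = F.batch_norm(x.view(1, b * c, h, w), rm, rv, None, None, True, 0.1, 1e-5)
assert torch.allclose(out.view(b, c, h, w), F.instance_norm(x, eps=1e-5), atol=1e-5)
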
def forward(self, input, fin_feature=None):
    fm = input
    for l in range(self.num_blocks):
        prefix = 'l' + str(l)
        fm = F.conv3d(fm, getattr(self, prefix + '_conv_w'), padding=1)
        fm = F.batch_norm(fm, getattr(self, prefix + '_running_mean'),
                          getattr(self, prefix + '_running_var'),
                          getattr(self, prefix + '_bn_w'),
                          getattr(self, prefix + '_bn_b'),
                          training=self.training)
        fm = F.max_pool3d(fm, 2, 2)
        fm = F.dropout(fm, p=self.drop_rate, training=self.training)
        if l != self.num_blocks - 1:
            fm = F.relu(fm)
    feature = fm.view(fm.size(0), -1)
    output = F.linear(feature, self.lr_fc_w, self.lr_fc_b)
    if fin_feature is None:
        return output
    else:
        return output, feature

def batchnorm(input, weight=None, bias=None, running_mean=None, running_var=None,
              training=True, eps=1e-5, momentum=0.1):
    ''' momentum = 1 restricts stats to the current mini-batch '''
    # This hack only works when momentum is 1 and avoids needing to track
    # running stats by substituting dummy variables.
    in_dim = input.data.size()[1]
    from Utils.config import USE_GPU
    if USE_GPU:
        running_mean = torch.zeros(in_dim).cuda()
        running_var = torch.ones(in_dim).cuda()
    else:
        running_mean = torch.zeros(in_dim)
        running_var = torch.ones(in_dim)
    return F.batch_norm(input, running_mean, running_var, weight, bias,
                        training, momentum, eps)

def forward(self, input, weight, bias, **kwargs):
    self._check_input_dim(input)
    exponential_average_factor = self.momentum if self.training else 0.0
    output = F.batch_norm(input, self.running_mean, self.running_var,
                          self.weight, self.bias, self.training,
                          exponential_average_factor, self.eps)
    # expand dimension to 2 if dimension is 1
    if weight.dim() == 1:
        weight = weight.unsqueeze(0)
    if bias.dim() == 1:
        bias = bias.unsqueeze(0)
    size = output.size()
    # expand dimension to 4 to calculate the affine transformation
    weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
    bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
    return weight * output + bias

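# The unsqueeze/expand above broadcasts per-sample (B, C) gains and biases
# over a (B, C, H, W) output, giving each sample its own affine transform.
# A small sketch with illustrative sizes:
import torch

w = torch.randn(2, 3)
out = torch.randn(2, 3, 4, 4)
scaled = w.unsqueeze(-1).unsqueeze(-1).expand(out.size()) * out
assert torch.allclose(scaled[1, 2], w[1, 2] * out[1, 2])
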
def aten_batch_norm(inputs, attributes, scope):
    inp, weight, bias, running_mean, running_var = inputs[:5]
    training, momentum, eps = inputs[5:8]
    # assert training is False
    net = current_network()
    if net is not None and has_trt_tensor(inputs):
        running_mean = running_mean.detach().cpu().numpy()
        running_var = running_var.detach().cpu().numpy()
        weight = weight.detach().cpu().numpy()
        bias = bias.detach().cpu().numpy()
        # Fold batch norm into a per-channel scale layer: y = scale * x + shift
        shift = (-running_mean / np.sqrt(running_var + eps)) * weight + bias
        scale = weight / np.sqrt(running_var + eps)
        layer = net.add_scale(inp, trt.ScaleMode.CHANNEL, shift, scale,
                              np.ones_like(shift))
        output = layer.get_output(0)
        output.name = scope
        layer.name = scope
        return [output]
    res = F.batch_norm(inp, running_mean, running_var, weight, bias,
                       bool(training), momentum, eps)
    return [res]

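# Numeric sketch of the scale/shift folding used in the TensorRT branch above:
# per channel, eval-mode batch norm (x - mean) / sqrt(var + eps) * w + b is
# algebraically the single affine map scale * x + shift. Values are made up.
import numpy as np

eps = 1e-5
x = np.random.randn(2, 3)          # (N, C) for simplicity
w, b = np.random.randn(3), np.random.randn(3)
mean, var = np.random.randn(3), np.random.rand(3) + 0.5
scale = w / np.sqrt(var + eps)
shift = (-mean / np.sqrt(var + eps)) * w + b
assert np.allclose(scale * x + shift, (x - mean) / np.sqrt(var + eps) * w + b)
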
def forward(self, x, y):
    # Class-conditional gains and biases are disabled in this variant:
    # gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
    # bias = self.bias(y).view(y.size(0), -1, 1, 1)
    # If using my batchnorm
    if self.mybn or self.cross_replica:
        return self.bn(x)
    # else:
    else:
        if self.norm_style == 'bn':
            out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                               self.training, 0.1, self.eps)
        elif self.norm_style == 'in':
            out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                  self.training, 0.1, self.eps)
        elif self.norm_style == 'gn':
            out = groupnorm(x, self.norm_style)
        elif self.norm_style == 'nonorm':
            out = x
        # return out * gain + bias
        return out

def forward(self, input):
    if not self.training:
        return batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    # Resize the input to (B, C, -1).
    input_shape = input.size()
    input = input.view(input_shape[0], self.num_features, -1)

    # sum(x) and sum(x^2)
    N = input.size(0) * input.size(2)
    xsum, xsqsum = sum_square(input)

    # all-reduce for global sum(x) and sum(x^2)
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(_ChildMessage(xsum, xsqsum, N))
    else:
        mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(xsum, xsqsum, N))

    # forward
    return batchnormtrain(input, mean, 1.0 / inv_std,
                          self.weight, self.bias).view(input_shape)

def forward(self, x):
    if x.requires_grad:
        # When gradients are needed, F.batch_norm will use extra memory
        # because its backward op computes gradients for weight/bias as well.
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        bias = self.bias - self.running_mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return x * scale + bias
    else:
        # When gradients are not needed, F.batch_norm is a single fused op
        # and provides more optimization opportunities.
        return F.batch_norm(
            x, self.running_mean, self.running_var,
            self.weight, self.bias,
            training=False, eps=self.eps,
        )

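# Sketch verifying the folding in the gradient branch above: with frozen
# statistics, batch norm is an affine map, so both branches must agree.
# All names and sizes here are illustrative.
import torch
import torch.nn.functional as F

c, eps = 4, 1e-5
x = torch.randn(2, c, 3, 3)
weight, bias = torch.randn(c), torch.randn(c)
mean, var = torch.randn(c), torch.rand(c) + 0.5
scale = weight * (var + eps).rsqrt()
shift = bias - mean * scale
fused = x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)
ref = F.batch_norm(x, mean, var, weight, bias, training=False, eps=eps)
assert torch.allclose(fused, ref, atol=1e-5)
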
def functional_conv_block(x: torch.Tensor,
                          weights: torch.Tensor,
                          biases: torch.Tensor,
                          bn_weights, bn_biases) -> torch.Tensor:
    """Performs 3x3 convolution, batch norm, ReLU activation, and 2x2 max
    pooling in a functional fashion.

    # Arguments:
        x: Input Tensor for the conv block
        weights: Weights for the convolutional block
        biases: Biases for the convolutional block
        bn_weights: Weights (gamma) for the batch norm layer
        bn_biases: Biases (beta) for the batch norm layer
    """
    x = F.conv2d(x, weights, biases, padding=1)
    x = F.batch_norm(x, running_mean=None, running_var=None,
                     weight=bn_weights, bias=bn_biases, training=True)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x

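# Hypothetical usage of functional_conv_block: all parameters are passed in
# explicitly, which is what makes this style convenient for meta-learning
# inner loops that differentiate through parameter updates. Shapes are
# illustrative.
import torch

x = torch.randn(4, 3, 32, 32)
conv_w = torch.randn(16, 3, 3, 3, requires_grad=True)
conv_b = torch.zeros(16, requires_grad=True)
bn_w = torch.ones(16, requires_grad=True)
bn_b = torch.zeros(16, requires_grad=True)
out = functional_conv_block(x, conv_w, conv_b, bn_w, bn_b)
assert out.shape == (4, 16, 16, 16)
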
def forward(self, x):
    if self.training and self.sync:
        if x.get_device() == self.devices[0]:
            extra = {
                'is_master': True,
                'master_queue': self.master_queue,
                'worker_queues': self.worker_queues,
                'worker_ids': self.worker_ids
            }
        else:
            extra = {
                'is_master': False,
                'master_queue': self.master_queue,
                'worker_queue': self.worker_queues[
                    self.worker_ids.index(x.get_device())]
            }
        return SyncBNFucntion.apply(
            x, self.weight, self.bias, self.running_mean, self.running_var,
            extra, self.training, self.momentum, self.eps)
    else:
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting
            # this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
        return F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

def forward(self, input):
    lname = self._write_caffe(input[1])
    input = input[0]

    if self.momentum is None:
        exponential_average_factor = 0.0
    else:
        exponential_average_factor = self.momentum

    if self.training and self.track_running_stats:
        # TODO: if statement only here to tell the jit to skip emitting
        # this when it is None
        if self.num_batches_tracked is not None:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

    result = F.batch_norm(
        input, self.running_mean, self.running_var, self.weight, self.bias,
        self.training or not self.track_running_stats,
        exponential_average_factor, self.eps)
    return result, lname

def forward(self, input, weight, bias, **kwargs):
    exponential_average_factor = 0.0
    if self.training and self.track_running_stats:
        self.num_batches_tracked += 1
        if self.momentum is None:  # use cumulative moving average
            exponential_average_factor = 1.0 / self.num_batches_tracked.item()
        else:  # use exponential moving average
            exponential_average_factor = self.momentum

    output = F.batch_norm(input, self.running_mean, self.running_var,
                          self.weight, self.bias,
                          self.training or not self.track_running_stats,
                          exponential_average_factor, self.eps)

    if weight.dim() == 1:
        weight = weight.unsqueeze(0)
    if bias.dim() == 1:
        bias = bias.unsqueeze(0)
    size = output.size()
    weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
    bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
    return weight * output + bias

def forward_with_args(self, num_features: int_or_int_dict,
                      eps: float, momentum: float,
                      inputs: torch.Tensor) -> torch.Tensor:
    if any(isinstance(arg, dict) for arg in [eps, momentum]):
        raise ValueError('eps, momentum do not support weighted sampling')

    if isinstance(num_features, dict):
        num_features = self.num_features

    weight, bias = self.weight, self.bias
    running_mean, running_var = self.running_mean, self.running_var

    if num_features < self.num_features:
        weight = weight[:num_features]
        bias = bias[:num_features]
        if running_mean is not None:
            running_mean = running_mean[:num_features]
        if running_var is not None:
            running_var = running_var[:num_features]

    if self.training:
        bn_training = True
    else:
        bn_training = (running_mean is None) and (running_var is None)

    return F.batch_norm(
        inputs,
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean if not self.training or self.track_running_stats else None,
        running_var if not self.training or self.track_running_stats else None,
        weight, bias, bn_training,
        momentum,  # originally exponential_average_factor in pytorch code
        eps,
    )

def _instance_norm(input, group, running_mean=None, running_var=None,
                   weight=None, bias=None, use_input_stats=None,
                   momentum=None, eps=None):
    b, c = input.size(0), input.size(1)

    # Repeat stored stats and affine transform params if necessary
    if running_mean is not None:
        running_mean_orig = running_mean
        running_mean = running_mean_orig.repeat(b)
    if running_var is not None:
        running_var_orig = running_var
        running_var = running_var_orig.repeat(b)

    # Apply instance norm
    input_reshaped = input.contiguous().view(1, int(b * c / group), group,
                                             *input.size()[2:])

    out = F.batch_norm(input_reshaped, running_mean, running_var, weight=weight,
                       bias=bias, training=use_input_stats, momentum=momentum,
                       eps=eps)

    # Reshape back
    if running_mean is not None:
        running_mean_orig.copy_(
            running_mean.view(b, int(c / group)).mean(0, keepdim=False))
    if running_var is not None:
        running_var_orig.copy_(
            running_var.view(b, int(c / group)).mean(0, keepdim=False))

    return out.view(b, c, *input.size()[2:])

def forward_mask(self, inputs, mask=None):
    if mask is None or inputs.shape[1] == self.flex_bn.num_features:
        return self.flex_bn.forward(inputs)

    # _BatchNorm official code
    if self.flex_bn.momentum is None:
        exponential_average_factor = 0.0
    else:
        exponential_average_factor = self.flex_bn.momentum

    if self.flex_bn.training and self.flex_bn.track_running_stats:
        if self.flex_bn.num_batches_tracked is not None:
            self.flex_bn.num_batches_tracked += 1
            if self.flex_bn.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / float(self.flex_bn.num_batches_tracked)
            else:  # use exponential moving average
                exponential_average_factor = self.flex_bn.momentum

    running_mean, running_var, weight, bias = self._select_params(mask)
    return F.batch_norm(
        inputs, running_mean, running_var, weight, bias,
        self.flex_bn.training or not self.flex_bn.track_running_stats,
        exponential_average_factor, self.flex_bn.eps)

def forward(self, input, z=None):
    # if input.dim() == 2, we switch to channel_last for efficient memory accessing
    channel_last = self.channel_last if input.dim() != 2 else True

    if (not self.training and self.track_running_stats and
            not self.channel_last and not self.fuse_relu and z is None):
        # fall back to pytorch implementation for inference
        return F.batch_norm(input, self.running_mean, self.running_var,
                            self.weight, self.bias, False, 0.0, self.eps)
    else:
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:
                exponential_average_factor = self.momentum
        return SyncBatchnormFunction.apply(
            input, z, self.weight, self.bias, self.running_mean,
            self.running_var, self.eps,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.process_group,
            channel_last, self.fuse_relu)

def forward(self, x):
    if self.training and self.sync:
        return DistributedSyncBNFucntion.apply(
            x, self.weight, self.bias, self.running_mean, self.running_var,
            self.training, self.momentum, self.eps, self.sync)
    else:
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting
            # this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
        return F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

def forward(self, x, w):
    self._check_input_dim(x)

    exponential_average_factor = 0.0
    if self.training and self.track_running_stats:
        self.num_batches_tracked += 1
        if self.momentum is None:  # use cumulative moving average
            exponential_average_factor = 1.0 / self.num_batches_tracked.item()
        else:  # use exponential moving average
            exponential_average_factor = self.momentum

    output = F.batch_norm(x, self.running_mean, self.running_var,
                          self.weight, self.bias,
                          self.training or not self.track_running_stats,
                          exponential_average_factor, self.eps)

    size = output.size()
    weight, bias = self.weight_proj(w) + 1, self.bias_proj(w)
    weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
    bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
    return weight * output + bias

def forward(self, input):
    self._check_input_dim(input)

    if not self.training:
        return F.batch_norm(input, self.running_mean, self.running_var,
                            self.weight, self.bias, False)

    if self.momentum is None:
        exponential_average_factor = 0.0
    else:
        exponential_average_factor = self.momentum

    if self.training and self.track_running_stats:
        if self.num_batches_tracked is not None:
            self.num_batches_tracked = self.num_batches_tracked + 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

    if self.training:
        bn_training = True
    else:
        bn_training = (self.running_mean is None) and (self.running_var is None)

    batch_mean = input.mean(dim=0)
    batch_var = ((input - self.running_mean) * (input - self.running_mean)).mean(dim=0)
    output = (input - self.running_mean) / (self.running_var + self.eps).sqrt()
    if self.training:
        with torch.no_grad():
            self.running_mean[...] = batch_mean
            self.running_var[...] = batch_var
    result = output * self.weight + self.bias
    return result

def forward(self, input):
    # TODO: weight is not quantized
    self._check_input_dim(input)

    if config.quantize_activation:
        qinput = self.quantize_input(input)
    else:
        qinput = input

    # if config.quantize_weights:
    #     qweight = quantize(self.weight, config.bias_preconditioner())
    #     qbias = quantize(self.bias, config.bias_preconditioner())
    # else:
    #     qweight = self.weight
    #     qbias = self.bias
    qweight = self.weight
    qbias = self.bias

    # exponential_average_factor is set to self.momentum
    # (when it is available) only so that it gets updated
    # in ONNX graph when this node is exported to ONNX.
    if self.momentum is None:
        exponential_average_factor = 0.0
    else:
        exponential_average_factor = self.momentum

    if self.training and self.track_running_stats:
        if self.num_batches_tracked is not None:
            self.num_batches_tracked = self.num_batches_tracked + 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

    return F.batch_norm(input, self.running_mean, self.running_var,
                        qweight, qbias,
                        self.training or not self.track_running_stats,
                        exponential_average_factor, self.eps)

def forward(self, bn_weight, bn_bias, *inputs):
    if self.training:
        # Save the current BN statistics for later
        self.prev_running_mean.copy_(self.running_mean)
        self.prev_running_var.copy_(self.running_var)

    # Create tensors that use shared allocations
    # One for the concatenation output (bn_input)
    # One for the ReLU output (relu_output)
    all_num_channels = [input.size(1) for input in inputs]
    size = list(inputs[0].size())
    for num_channels in all_num_channels[1:]:
        size[1] += num_channels

    storage = self.shared_allocation_1.storage_for(inputs[0])
    bn_input_var = Variable(type(inputs[0])(storage).resize_(size), volatile=True)
    relu_output = type(inputs[0])(storage).resize_(size)

    # Create variable, using existing storage
    torch.cat(inputs, dim=1, out=bn_input_var.data)

    # Do batch norm
    bn_weight_var = Variable(bn_weight)
    bn_bias_var = Variable(bn_bias)
    bn_output_var = F.batch_norm(bn_input_var, self.running_mean, self.running_var,
                                 bn_weight_var, bn_bias_var,
                                 training=self.training, momentum=self.momentum,
                                 eps=self.eps)

    # Do ReLU - and have the output be in the intermediate storage
    torch.clamp(bn_output_var.data, 0, 1e100, out=relu_output)

    self.save_for_backward(bn_weight, bn_bias, *inputs)
    return relu_output

def prepare_backward(self):
    bn_weight, bn_bias = self.saved_tensors[:2]
    inputs = self.saved_tensors[2:]

    # Temporarily reset batch norm statistics
    self.curr_running_mean.copy_(self.running_mean)
    self.curr_running_var.copy_(self.running_var)
    self.running_mean.copy_(self.prev_running_mean)
    self.running_var.copy_(self.prev_running_var)

    # Re-do the forward pass to re-populate the shared storage
    all_num_channels = [input.size(1) for input in inputs]
    size = list(inputs[0].size())
    for num_channels in all_num_channels[1:]:
        size[1] += num_channels

    storage1 = self.shared_allocation_1.storage_for(inputs[0])
    self.bn_input_var = Variable(type(inputs[0])(storage1).resize_(size),
                                 requires_grad=True)
    storage2 = self.shared_allocation_2.storage_for(inputs[0])
    self.relu_output = type(inputs[0])(storage2).resize_(size)

    # Create variable, using existing storage
    torch.cat(inputs, dim=1, out=self.bn_input_var.data)

    # Do batch norm
    self.bn_weight_var = Variable(bn_weight, requires_grad=True)
    self.bn_bias_var = Variable(bn_bias, requires_grad=True)
    self.bn_output_var = F.batch_norm(self.bn_input_var, self.running_mean,
                                      self.running_var, self.bn_weight_var,
                                      self.bn_bias_var, training=self.training,
                                      momentum=self.momentum, eps=self.eps)

    # Do ReLU
    torch.clamp(self.bn_output_var.data, 0, 1e100, out=self.relu_output)

def forward(self, input):
    self._check_input_dim(input)

    if self.momentum is None:
        exponential_average_factor = 0.0
    else:
        exponential_average_factor = self.momentum

    if self.training and self.track_running_stats:
        if self.num_batches_tracked is not None:
            self.num_batches_tracked += 1
            if self.momentum is None:
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:
                exponential_average_factor = self.momentum

    if self.affine:
        if self.weight_eps is None:
            weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
            bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
        else:
            weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps
            bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps
    else:
        weight = None
        bias = None

    return F.batch_norm(input, self.running_mean, self.running_var, weight, bias,
                        self.training or not self.track_running_stats,
                        exponential_average_factor, self.eps)

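# The affine branch above draws weights via the reparameterization trick,
# w = mu + exp(log_sigma) * eps with eps ~ N(0, 1), so gradients flow to both
# mu and log_sigma. A tiny standalone sketch:
import torch

mu = torch.zeros(3, requires_grad=True)
log_sigma = torch.zeros(3, requires_grad=True)
w = mu + torch.exp(log_sigma) * torch.randn(3)
w.sum().backward()
assert mu.grad is not None and log_sigma.grad is not None
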
def forward(self, input, weight=None, bias=None):
    # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
    if not (self._is_parallel and self.training):
        return F.batch_norm(input, self.running_mean, self.running_var,
                            self.weight, self.bias, self.training,
                            self.momentum, self.eps)

    # Resize the input to (B, C, -1).
    input_shape = input.size()
    input = input.view(input.size(0), self.num_features, -1)

    # Compute the sum and square-sum.
    sum_size = input.size(0) * input.size(2)
    input_sum = _sum_ft(input)
    input_ssum = _sum_ft(input ** 2)

    # Reduce-and-broadcast the statistics.
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(
            _ChildMessage(input_sum, input_ssum, sum_size))
    else:
        mean, inv_std = self._slave_pipe.run_slave(
            _ChildMessage(input_sum, input_ssum, sum_size))

    # Compute the output.
    if self.affine:
        if weight is None or bias is None:
            weight = self.weight
            bias = self.bias
        # MJY:: Fuse the multiplication for speed.
        output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(
            inv_std * weight) + _unsqueeze_ft(bias)
    else:
        output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

    # Reshape it.
    return output.view(input_shape)

def forward(self, input):
    self._check_input_dim(input)

    # hack to work around model.eval() issue
    if not self.training:
        self.eval_momentum = 0  # set the momentum to zero when the model is validating

    if self.momentum is None:
        exponential_average_factor = 0.0
    else:
        exponential_average_factor = self.momentum if self.training else self.eval_momentum

    if self.track_running_stats:
        if self.num_batches_tracked is not None:
            self.num_batches_tracked = self.num_batches_tracked + 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential moving average
                exponential_average_factor = self.momentum if self.training else self.eval_momentum

    return F.batch_norm(
        input, self.running_mean, self.running_var, self.weight, self.bias,
        training=True, momentum=exponential_average_factor,
        eps=self.eps)  # set training to True so it calculates the norm of the batch

def forward(self, x):
    N, C, H, W = x.shape
    x = functional.batch_norm(
        x.view(-1, C * self.num_splits, H, W),
        self.running_mean, self.running_var,
        self.weight.repeat(self.num_splits),
        self.bias.repeat(self.num_splits),
        True, self.momentum, self.eps).view(N, C, H, W)
    # x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
    #                           self.training, self.momentum, self.eps)
    if self.activation == "relu":
        return functional.relu(x, inplace=True)
    elif self.activation == "leaky_relu":
        return functional.leaky_relu(x, negative_slope=self.activation_param, inplace=True)
    elif self.activation == "elu":
        return functional.elu(x, alpha=self.activation_param, inplace=True)
    elif self.activation == "identity":
        return x
    else:
        raise RuntimeError("Unknown activation function {}".format(self.activation))

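# The view in the forward above is the "ghost batch norm" trick: N samples are
# regrouped so each channel's statistics come from a sub-batch of
# N / num_splits samples. A minimal check with num_splits=2 and no affine:
import torch
import torch.nn.functional as F

n, c, h, w, splits = 4, 3, 2, 2, 2
x = torch.randn(n, c, h, w)
rm, rv = torch.zeros(c * splits), torch.ones(c * splits)
out = F.batch_norm(x.view(-1, c * splits, h, w), rm, rv, None, None,
                   True, 0.1, 1e-5).view(n, c, h, w)
g0 = x[0::2]  # samples that share the first set of ghost channels
m = g0.mean(dim=(0, 2, 3), keepdim=True)
v = g0.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
assert torch.allclose(out[0::2], (g0 - m) / (v + 1e-5).sqrt(), atol=1e-4)
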
def batch_norm(x, params, base, mode):
    return F.batch_norm(x,
                        weight=params[base + '.weight'],
                        bias=params[base + '.bias'],
                        running_mean=params[base + '.running_mean'],
                        running_var=params[base + '.running_var'],
                        training=mode)

def forward(self, input):
    self._check_input_dim(input)
    return F.batch_norm(
        input, self.running_mean, self.running_var, self.weight, self.bias,
        self.training, self.momentum, self.eps)