Example #1
0
 def forward(self, x, y=None):
     """Apply batch normalization with the layer's learned gain and bias.

     Dispatches to the custom/cross-replica BN module when configured;
     otherwise falls back to ``F.batch_norm`` using the stored running
     statistics. ``y`` is accepted for interface compatibility and unused.
     """
     # Plain path: torch's batch_norm consumes gain/bias as flat
     # per-channel weight/bias vectors directly.
     if not (self.cross_replica or self.mybn):
         return F.batch_norm(x, self.stored_mean, self.stored_var,
                             self.gain, self.bias, self.training,
                             self.momentum, self.eps)
     # Custom BN path: gain/bias are passed as broadcastable
     # (1, C, 1, 1) tensors.
     scale = self.gain.view(1, -1, 1, 1)
     shift = self.bias.view(1, -1, 1, 1)
     return self.bn(x, gain=scale, bias=shift)
Example #2
0
 def forward(self, x, y):
     """Apply class-conditional normalization to ``x``.

     Args:
         x: input feature map of shape (N, C, H, W).
         y: per-sample conditioning input fed to the ``gain``/``bias``
            sub-modules; ``y.size(0)`` must equal the batch size.

     Returns:
         Normalized tensor of the same shape as ``x``.

     Raises:
         ValueError: if ``self.norm_style`` is not one of
             'bn', 'in', 'gn', 'nonorm'.
     """
     # Per-sample (N, C, 1, 1) scale and shift predicted from y.
     # self.gain(y) / self.bias(y) already return tensors; do NOT wrap
     # them in torch.Tensor(), which copy-constructs and detaches the
     # result from the autograd graph (the conditioning would never train).
     gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
     bias = self.bias(y).view(y.size(0), -1, 1, 1)
     # Custom/cross-replica BN fuses gain and bias into the norm itself.
     if self.mybn or self.cross_replica:
         return self.bn(x, gain=gain, bias=bias)
     else:
         # Normalize without affine params, then apply gain/bias manually.
         if self.norm_style == 'bn':
             out = F.batch_norm(x, self.stored_mean, self.stored_var, None,
                                None, self.training, 0.1, self.eps)
         elif self.norm_style == 'in':
             out = F.instance_norm(x, self.stored_mean, self.stored_var,
                                   None, None, self.training, 0.1, self.eps)
         elif self.norm_style == 'gn':
             # NOTE(review): attribute is spelled `normstyle` here but
             # `norm_style` everywhere else — confirm which attribute the
             # enclosing class actually defines.
             out = groupnorm(x, self.normstyle)
         elif self.norm_style == 'nonorm':
             out = x
         else:
             # Previously an unknown style fell through to a NameError on
             # `out`; fail with a clear message instead.
             raise ValueError('unknown norm_style: %r' % (self.norm_style,))
         return out * gain + bias
Example #3
0
    def forward(self, input, gain=None, bias=None):
        """Synchronized batch-norm forward pass.

        Falls back to ``F.batch_norm`` when not running data-parallel or
        when in eval mode; otherwise reduces per-device statistics through
        the master/slave pipes and normalizes with the synced mean/inv-std.

        Args:
            input: tensor of shape (B, C, *) to normalize.
            gain: optional multiplicative term applied in the synced path
                (fused into the normalization) or added in the fallback
                path. NOTE(review): in the synced path ``bias`` is used
                unconditionally whenever ``gain`` is given — confirm
                callers always pass both together.
            bias: optional additive term.

        Returns:
            Normalized tensor with the same shape as ``input``.
        """
        # Not data-parallel (or eval mode): use PyTorch's implementation
        # with the stored running statistics, then add gain/bias.
        if not (self._is_parallel and self.training):
            out = F.batch_norm(input, self.running_mean, self.running_var,
                               self.weight, self.bias, self.training,
                               self.momentum, self.eps)
            if gain is not None:
                out = out + gain
            if bias is not None:
                out = out + bias
            return out

        # Flatten spatial dims to (B, C, -1) so statistics reduce over
        # everything except the channel dimension.
        input_shape = input.size()
        input = input.view(input.size(0), input.size(1), -1)

        # Per-device sum and sum-of-squares; sum_size is the per-channel
        # element count contributed by this device.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input**2)

        # Reduce-and-broadcast the statistics across devices: the copy
        # with parallel id 0 aggregates as master, the rest are slaves.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(
                _ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(
                _ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if gain is not None:
            # Externally supplied per-sample gain/bias (class-conditional
            # BN path): fold the gain into the inv-std multiply.
            output = (input - _unsqueeze_ft(mean)) * (
                _unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
        elif self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(
                inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Restore the original shape.
        return output.view(input_shape)