コード例 #1
0
 def count_flops(self, h, w):
     """Estimate this block's FLOPs for an input of spatial size h x w.

     Args:
       h: Input height.
       w: Input width.

     Returns:
       The sum of the counts reported by `base.count_conv_flops` for conv0,
       conv1, conv2 and, when `self.use_projection` is set, conv_shortcut.
     """
     # conv0 and conv1 are both counted at the block's input resolution.
     flops_in = base.count_conv_flops(self.in_ch, self.conv0, h, w)
     flops_mid = base.count_conv_flops(self.width, self.conv1, h, w)
     # With stride > 1, the shortcut and conv2 are counted at the input size
     # divided by the stride.
     stride = self.stride
     out_h, out_w = (h / stride, w / stride) if stride > 1 else (h, w)
     flops_skip = (
         base.count_conv_flops(self.in_ch, self.conv_shortcut, out_h, out_w)
         if self.use_projection else 0)
     flops_out = base.count_conv_flops(self.width, self.conv2, out_h, out_w)
     return sum((flops_in, flops_mid, flops_out, flops_skip))
コード例 #2
0
 def count_flops(self, h, w):
     """Estimate the FLOPs of the whole network for an h x w input image.

     Args:
       h: Input image height.
       w: Input image width.

     Returns:
       A tuple ``(flops, total)``. ``flops`` lists, in order, the count for
       the initial conv, each body block, the final conv and the classifier.
       ``total`` is their sum.
     """
     flops = []
     # Stem: the initial conv consumes the 3-channel input image.
     flops.append(base.count_conv_flops(3, self.initial_conv, h, w))
     # Reduce the resolution by the initial conv's own stride rather than a
     # hard-coded factor of 2, so the count cannot drift from the layer config.
     # NOTE(review): assumes initial_conv exposes a per-axis `stride` sequence
     # like hk.Conv2D (as the stem convs of the sibling network do).
     init_stride = self.initial_conv.stride
     if any(s > 1 for s in init_stride):
         h, w = h / init_stride[0], w / init_stride[1]
     # Body: each block is counted at its input resolution, and strided blocks
     # shrink the resolution seen by the following ones.
     for block in self.blocks:
         flops.append(block.count_flops(h, w))
         if block.stride > 1:
             h, w = h / block.stride, w / block.stride
     # Head: the final conv takes the last block's output channels.
     out_ch = self.blocks[-1].out_ch
     flops.append(base.count_conv_flops(out_ch, self.final_conv, h, w))
     # Classifier cost is counted as final_conv.output_channels * fc.output_size.
     flops.append(self.final_conv.output_channels * self.fc.output_size)
     return flops, sum(flops)
コード例 #3
0
 def count_flops(self, h, w):
   """Estimate this block's FLOPs (including its SE unit) for h x w input.

   Args:
     h: Input height.
     w: Input width.

   Returns:
     The sum of the counts from `base.count_conv_flops` for conv0, conv1,
     conv2 and, when `self.use_projection` is set, conv_shortcut, plus the
     cost of the two SE fully-connected layers.
   """
   # conv0 and conv1 are both counted at the block's input resolution.
   flops_in = base.count_conv_flops(self.in_ch, self.conv0, h, w)
   flops_mid = base.count_conv_flops(self.width, self.conv1, h, w)
   # With stride > 1, the shortcut and conv2 are counted at the input size
   # divided by the stride.
   stride = self.stride
   out_h, out_w = (h / stride, w / stride) if stride > 1 else (h, w)
   flops_skip = (
       base.count_conv_flops(self.in_ch, self.conv_shortcut, out_h, out_w)
       if self.use_projection else 0)
   # SE works on avg-pooled activations, so its cost ignores h and w:
   # fc0.output_size * width + fc0.output_size * fc1.output_size.
   squeeze = self.se.fc0.output_size
   flops_se = squeeze * self.width + squeeze * self.se.fc1.output_size
   flops_out = base.count_conv_flops(self.width, self.conv2, out_h, out_w)
   return sum((flops_in, flops_mid, flops_se, flops_out, flops_skip))
コード例 #4
0
ファイル: nfnet.py プロジェクト: isseebx123/deepmind-research
 def count_flops(self, h, w):
     """Estimate the FLOPs of the whole network for an h x w input image.

     Args:
       h: Input image height.
       w: Input image width.

     Returns:
       A tuple ``(flops, total)``. ``flops`` lists, in order, the count for
       each conv layer in the stem, each body block, the final conv and the
       classifier. ``total`` is their sum.
     """
     flops = []
     # Stem: only hk.Conv2D layers are counted. Track the channel count and
     # spatial size they produce so each conv sees the right input.
     ch = 3
     for module in self.stem.layers:
         if not isinstance(module, hk.Conv2D):
             continue  # Non-conv stem layers contribute no FLOPs here.
         flops.append(base.count_conv_flops(ch, module, h, w))
         # Generator instead of a throwaway list inside any().
         if any(s > 1 for s in module.stride):
             h, w = h / module.stride[0], w / module.stride[1]
         ch = module.output_channels
     # Body: each block is counted at its input resolution, and strided blocks
     # shrink the resolution seen by the following ones.
     for block in self.blocks:
         flops.append(block.count_flops(h, w))
         if block.stride > 1:
             h, w = h / block.stride, w / block.stride
     # Head: the final conv takes the last block's output channels.
     out_ch = self.blocks[-1].out_ch
     flops.append(base.count_conv_flops(out_ch, self.final_conv, h, w))
     # Classifier cost is counted as final_conv.output_channels * fc.output_size.
     flops.append(self.final_conv.output_channels * self.fc.output_size)
     return flops, sum(flops)