def count_flops(self, h, w):
    """Return the conv FLOPs of this block for an ``h`` x ``w`` input map.

    ``conv0`` and ``conv1`` are both counted at the incoming resolution.
    ``conv1`` is where a strided block reduces resolution, so the optional
    projection shortcut and ``conv2`` are counted at the reduced size.

    Args:
      h: Height of the block's input feature map.
      w: Width of the block's input feature map.

    Returns:
      The sum of the expansion, depthwise, contraction and (when
      ``use_projection`` is set) shortcut conv FLOPs.
    """
    conv_cost = base.count_conv_flops

    # Stages that run at the block's input resolution.
    expand = conv_cost(self.in_ch, self.conv0, h, w)
    depthwise = conv_cost(self.width, self.conv1, h, w)

    # Resolution seen by every stage after conv1.
    if self.stride > 1:
        out_h, out_w = h / self.stride, w / self.stride
    else:
        out_h, out_w = h, w

    shortcut = (
        conv_cost(self.in_ch, self.conv_shortcut, out_h, out_w)
        if self.use_projection
        else 0
    )
    contract = conv_cost(self.width, self.conv2, out_h, out_w)

    # Accumulate in the same order as before so float totals match exactly.
    return expand + depthwise + contract + shortcut
def count_flops(self, h, w):
    """Count per-stage and total FLOPs of the network for an ``h`` x ``w`` input.

    The initial conv is counted on a 3-channel input at full resolution.
    The resolution is then reduced by that conv's stride before the blocks,
    the final conv and the classifier are counted.

    Args:
      h: Input height.
      w: Input width.

    Returns:
      A tuple ``(flops, total)``. ``flops`` lists, in order: the initial conv,
      one entry per block (from ``block.count_flops``), the final conv and the
      classifier. ``total`` is ``sum(flops)``.
    """
    flops = [base.count_conv_flops(3, self.initial_conv, h, w)]

    # Derive the stem downsampling from the conv itself rather than assuming
    # a stride of 2, matching how stem strides are read from each conv
    # module. A bare-int stride is accepted as well as a per-axis sequence.
    stride = self.initial_conv.stride
    if isinstance(stride, int):
        stride = (stride, stride)
    h, w = h / stride[0], w / stride[1]

    # Body FLOPs: each block is counted at its input resolution; a strided
    # block reduces the resolution seen by the next one.
    for block in self.blocks:
        flops.append(block.count_flops(h, w))
        if block.stride > 1:
            h, w = h / block.stride, w / block.stride

    # Head: final conv over the last block's output channels.
    out_ch = self.blocks[-1].out_ch
    flops.append(base.count_conv_flops(out_ch, self.final_conv, h, w))

    # Classifier counted as in_features * out_features, independent of h/w.
    # Presumably features are pooled before ``fc``; confirm against __call__.
    flops.append(self.final_conv.output_channels * self.fc.output_size)
    return flops, sum(flops)
def count_flops(self, h, w):
    """Return the FLOPs of this squeeze-excite block for an ``h`` x ``w`` input.

    ``conv0`` and ``conv1`` are both counted at the incoming resolution.
    ``conv1`` is where a strided block reduces resolution, so the optional
    projection shortcut and ``conv2`` are counted at the reduced size. The
    SE branch operates on average-pooled activations, so its cost does not
    depend on ``h``/``w``.

    Args:
      h: Height of the block's input feature map.
      w: Width of the block's input feature map.

    Returns:
      The sum of the expansion, depthwise, SE, contraction and (when
      ``use_projection`` is set) shortcut FLOPs.
    """
    conv_cost = base.count_conv_flops

    # Stages that run at the block's input resolution.
    expand = conv_cost(self.in_ch, self.conv0, h, w)
    depthwise = conv_cost(self.width, self.conv1, h, w)

    # Resolution seen by every stage after conv1.
    if self.stride > 1:
        out_h, out_w = h / self.stride, w / self.stride
    else:
        out_h, out_w = h, w

    shortcut = (
        conv_cost(self.in_ch, self.conv_shortcut, out_h, out_w)
        if self.use_projection
        else 0
    )

    # SE: fc0 counted against `width` inputs, then fc0 -> fc1, both on
    # pooled features (no spatial factor).
    squeeze_width = self.se.fc0.output_size
    excite = squeeze_width * self.width + squeeze_width * self.se.fc1.output_size

    contract = conv_cost(self.width, self.conv2, out_h, out_w)

    # Accumulate in the same order as before so float totals match exactly.
    return expand + depthwise + excite + contract + shortcut
def count_flops(self, h, w):
    """Count per-stage and total FLOPs of the network for an ``h`` x ``w`` input.

    Each ``hk.Conv2D`` in ``self.stem.layers`` is counted at the current
    resolution, and any stride it has is applied before the next layer. Other
    stem layers are skipped. The blocks, the final conv and the classifier
    follow.

    Args:
      h: Input height.
      w: Input width.

    Returns:
      A tuple ``(flops, total)``. ``flops`` lists, in order: one entry per stem
      conv, one per block (from ``block.count_flops``), the final conv and the
      classifier. ``total`` is ``sum(flops)``.
    """
    flops = []

    # Stem: track channels and resolution through the conv layers only.
    ch = 3  # the first stem conv is counted with 3 input channels
    for module in self.stem.layers:
        if not isinstance(module, hk.Conv2D):
            continue
        flops.append(base.count_conv_flops(ch, module, h, w))
        # Generator, not a list: `any` can short-circuit without building one.
        if any(step > 1 for step in module.stride):
            h, w = h / module.stride[0], w / module.stride[1]
        ch = module.output_channels

    # Body FLOPs: each block is counted at its input resolution; a strided
    # block reduces the resolution seen by the next one.
    for block in self.blocks:
        flops.append(block.count_flops(h, w))
        if block.stride > 1:
            h, w = h / block.stride, w / block.stride

    # Head: final conv over the last block's output channels.
    out_ch = self.blocks[-1].out_ch
    flops.append(base.count_conv_flops(out_ch, self.final_conv, h, w))

    # Classifier counted as in_features * out_features, independent of h/w.
    # Presumably features are pooled before ``fc``; confirm against __call__.
    flops.append(self.final_conv.output_channels * self.fc.output_size)
    return flops, sum(flops)