def testSuccessResnetV2(self):
  """Checks the connectivity-based conv -> gamma mapping on a ResNet v2.

  The network comes from `build_resnet` with the v2 block/net builders. The
  loops and scope names below assume it has two blocks of two bottleneck units
  each (TODO confirm against build_resnet).
  """
  build_resnet(resnet_v2.resnet_v2_block, resnet_v2.resnet_v2)
  mapper = gamma_mapper.ConvGammaMapperByConnectivity()
  # Check all "regular" convs, i.e. those connected to their own batch norm
  # without any residual connections involved.
  for block in (1, 2):
    for unit in (1, 2):
      for conv in (1, 2):
        self.assertGammaMatchesConv(
            mapper, 'resnet_v2/block%d/unit_%d/bottleneck_v2/conv%d' %
            (block, unit, conv))

  # This diagram shows the convs and batch norms that do not have a
  # one-to-one mapping:
  #
  # CONVS                       BATCH-NORMS
  #
  # block1/unit_1/shortcut --+
  #                          |
  # block1/unit_1/conv3 -----+--> block1/unit_2/preact
  #                          |
  # block1/unit_2/conv3 -----+--> block2/unit_1/preact
  #
  #
  # block2/unit_1/shortcut --+
  #                          |
  # block2/unit_1/conv3 -----+--> block2/unit_2/preact
  #                          |
  # block2/unit_2/conv3 -----+--> postnorm
  #
  # Each conv reaches every batch norm on its own row and on the rows below
  # it within the same group. For example, block1/unit_1/conv3 maps to both
  # block1/unit_2/preact and block2/unit_1/preact, while block1/unit_2/conv3
  # maps only to block2/unit_1/preact. This connectivity is tested below.
  self.assertConvsConnectedToGammas([
      'resnet_v2/block1/unit_1/bottleneck_v2/shortcut/Conv2D',
      'resnet_v2/block1/unit_1/bottleneck_v2/conv3/Conv2D'
  ], [
      'resnet_v2/block1/unit_2/bottleneck_v2/preact/gamma',
      'resnet_v2/block2/unit_1/bottleneck_v2/preact/gamma'
  ], mapper)

  self.assertConvsConnectedToGammas([
      'resnet_v2/block1/unit_2/bottleneck_v2/conv3/Conv2D',
  ], [
      'resnet_v2/block2/unit_1/bottleneck_v2/preact/gamma',
  ], mapper)

  self.assertConvsConnectedToGammas([
      'resnet_v2/block2/unit_1/bottleneck_v2/shortcut/Conv2D',
      'resnet_v2/block2/unit_1/bottleneck_v2/conv3/Conv2D'
  ], [
      'resnet_v2/block2/unit_2/bottleneck_v2/preact/gamma',
      'resnet_v2/postnorm/gamma'
  ], mapper)

  self.assertConvsConnectedToGammas([
      'resnet_v2/block2/unit_2/bottleneck_v2/conv3/Conv2D',
  ], [
      'resnet_v2/postnorm/gamma',
  ], mapper)
def testSuccessResnetV1(self):
  """Checks the connectivity-based conv -> gamma mapping on a ResNet v1.

  In v1 every checked conv is expected to map one-to-one to its own batch
  norm. That covers the unit_1 shortcut of each block and conv1..conv3 of
  every unit.
  """
  build_resnet(resnet_v1.resnet_v1_block, resnet_v1.resnet_v1)
  mapper = gamma_mapper.ConvGammaMapperByConnectivity()

  # Gather every conv scope first, in the same order the checks run.
  conv_scopes = []
  for block_id in (1, 2):
    # Only the first unit of each block is checked for a shortcut conv.
    conv_scopes.append(
        'resnet_v1/block%d/unit_1/bottleneck_v1/shortcut' % block_id)
    conv_scopes.extend(
        'resnet_v1/block%d/unit_%d/bottleneck_v1/conv%d' %
        (block_id, unit_id, conv_id)
        for unit_id in (1, 2)
        for conv_id in (1, 2, 3))

  for scope in conv_scopes:
    self.assertGammaMatchesConv(mapper, scope)
def createMapper(self, connectivity):
  """Returns a conv -> gamma mapper of the requested flavor.

  Args:
    connectivity: if truthy, build a `ConvGammaMapperByConnectivity`;
      otherwise build a `ConvGammaMapperByName`.
  """
  return (gamma_mapper.ConvGammaMapperByConnectivity()
          if connectivity else gamma_mapper.ConvGammaMapperByName())