Example 1
    def _decode_block_string(self, block_string):
        """Gets a block through a string notation of arguments."""
        if six.PY2:
            assert isinstance(block_string, (str, unicode))
        else:
            assert isinstance(block_string, str)
        ops = block_string.split('_')
        options = {}
        for op in ops:
            # Split each token into its alphabetic key and trailing numeric
            # value, e.g. 'k3' -> ['k', '3', ''].
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        if 's' not in options or len(options['s']) != 2:
            raise ValueError('Strides options should be a pair of integers.')

        return efficientnet_model.BlockArgs(
            kernel_size=int(options['k']),
            num_repeat=int(options['r']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            expand_ratio=int(options['e']),
            id_skip=('noskip' not in block_string),
            se_ratio=float(options['se']) if 'se' in options else None,
            strides=[int(options['s'][0]),
                     int(options['s'][1])],
            conv_type=int(options['c']) if 'c' in options else 0,
            fused_conv=int(options['f']) if 'f' in options else 0,
            super_pixel=int(options['p']) if 'p' in options else 0,
            condconv=('cc' in block_string))
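For reference, a minimal usage sketch of the decoder above. The class name BlockDecoder and the example block string are assumptions based on the upstream EfficientNet builder code, not part of the snippet itself; the expected field values follow directly from the decoding logic shown above.

    # Minimal usage sketch; BlockDecoder and the example string are assumptions.
    decoder = BlockDecoder()
    block = decoder._decode_block_string('r1_k3_s22_e6_i3_o6_se0.8_noskip')
    assert block.kernel_size == 3
    assert block.num_repeat == 1
    assert block.strides == [2, 2]
    assert block.se_ratio == 0.8
    assert block.id_skip is False  # 'noskip' disables the identity skip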
Example 2
 def test_reduction_endpoint_with_single_block_without_sp(self):
   """Test reduction point with single block/layer."""
   images = tf.zeros((10, 128, 128, 3), dtype=tf.float32)
   global_params = efficientnet_model.GlobalParams(
       1.0,
       1.0,
       0,
       'channels_last',
       num_classes=10,
       batch_norm=utils.TpuBatchNormalization)
   blocks_args = [
       efficientnet_model.BlockArgs(
           kernel_size=3,
           num_repeat=1,
           input_filters=3,
           output_filters=6,
           expand_ratio=6,
           id_skip=False,
           strides=[2, 2],
           se_ratio=0.8,
           conv_type=0,
           fused_conv=0,
           super_pixel=0)
   ]
   model = efficientnet_model.Model(blocks_args, global_params)
   _ = model(images, training=True)
   self.assertIn('reduction_1', model.endpoints)
   # A single block should produce one and only one reduction endpoint.
   self.assertNotIn('reduction_2', model.endpoints)
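Note: the BlockArgs used in this test could equivalently be written in the string notation decoded in Example 1, roughly 'r1_k3_s22_e6_i3_o6_se0.8_noskip'; that string is inferred from the decoder above rather than taken from the test itself.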
Example 3
 def test_bottleneck_block_with_superpixel_layer(self):
   """Test for creating a model with fused bottleneck block arguments."""
   images = tf.zeros((10, 128, 128, 3), dtype=tf.float32)
   global_params = efficientnet_model.GlobalParams(
       1.0,
       1.0,
       0,
       'channels_last',
       num_classes=10,
       batch_norm=utils.TpuBatchNormalization)
   blocks_args = [
       efficientnet_model.BlockArgs(
           kernel_size=3,
           num_repeat=3,
           input_filters=3,
           output_filters=6,
           expand_ratio=6,
           id_skip=True,
           strides=[2, 2],
           conv_type=0,
           fused_conv=0,
           super_pixel=1)
   ]
   model = efficientnet_model.Model(blocks_args, global_params)
   outputs = model(images, training=True)
   self.assertEqual((10, 10), outputs.shape)
Example 4
 def test_variables(self):
   """Test for variables in blocks to be included in `model.variables`."""
   images = tf.zeros((10, 128, 128, 3), dtype=tf.float32)
   global_params = efficientnet_model.GlobalParams(
       1.0,
       1.0,
       0,
       'channels_last',
       num_classes=10,
       batch_norm=utils.TpuBatchNormalization)
   blocks_args = [
       efficientnet_model.BlockArgs(
           kernel_size=3,
           num_repeat=3,
           input_filters=3,
           output_filters=6,
           expand_ratio=6,
           id_skip=False,
           strides=[2, 2],
           se_ratio=0.8,
           conv_type=0,
           fused_conv=0,
           super_pixel=0)
   ]
   model = efficientnet_model.Model(blocks_args, global_params)
   _ = model(images, training=True)
   var_names = {var.name for var in model.variables}
   self.assertIn('model/blocks_0/conv2d/kernel:0', var_names)
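The test methods in Examples 2-4 would normally sit inside a tf.test.TestCase subclass. A minimal sketch of that harness follows; the class name and import layout are assumptions, not taken from the examples above.

    # Minimal test harness sketch (class and module names are assumptions).
    import tensorflow as tf

    import efficientnet_model
    import utils


    class EfficientnetModelTest(tf.test.TestCase):
      # ... the test methods from Examples 2-4 go here ...
      pass


    if __name__ == '__main__':
      tf.test.main()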