Пример #1
0
    def extract_features(self, inputs):
        """Build the backbone selected by ``self.cfg['net_type']``.

        ``inputs`` is first passed through ``self.center_inputs``. The result
        is then fed to a ResNet, MobileNet or EfficientNet backbone, chosen by
        which family name occurs as a substring of ``net_type``.

        Args:
            inputs: Image batch handed to ``self.center_inputs``.
                NOTE(review): presumably an NHWC tensor with 3 colour
                channels (the EfficientNet scaling constant has shape
                [1, 1, 3]) -- confirm against callers.

        Returns:
            ``(net, end_points)`` exactly as returned by the selected backbone
            builder.

        Raises:
            ValueError: If ``net_type`` contains none of ``'resnet'``,
                ``'mobilenet'`` or ``'efficientnet'``.
            KeyError: If a ResNet/MobileNet ``net_type`` has no entry in
                ``net_funcs``.

        Side effects:
            For EfficientNet, missing ``'use_batch_norm'`` /
            ``'use_drop_out'`` entries are written into ``self.cfg`` as
            ``False``.
        """
        im_centered = self.center_inputs(inputs)
        net_type = self.cfg['net_type']
        if 'resnet' in net_type:
            net_fun = net_funcs[net_type]
            with slim.arg_scope(resnet_v1.resnet_arg_scope()):
                # global_pool=False keeps a spatial feature map rather than a
                # pooled vector.
                net, end_points = net_fun(im_centered,
                                          global_pool=False,
                                          output_stride=16,
                                          is_training=False)
        elif 'mobilenet' in net_type:
            net_fun = net_funcs[net_type]
            with slim.arg_scope(mobilenet_v2.training_scope()):
                net, end_points = net_fun(im_centered)
        elif 'efficientnet' in net_type:
            # Fill in missing regularisation switches. Plain item assignment
            # (rather than dict.setdefault) goes through the mapping's own
            # __setitem__, which matters for dict subclasses that override it.
            for key in ('use_batch_norm', 'use_drop_out'):
                if key not in self.cfg:
                    self.cfg[key] = False

            im_centered /= tf.constant(eff.STDDEV_RGB, shape=[1, 1, 3])
            net, end_points = eff.build_model_base(
                im_centered,
                net_type,
                use_batch_norm=self.cfg['use_batch_norm'],
                drop_out=self.cfg['use_drop_out'])
        else:
            raise ValueError(f"Unknown network of type {net_type}")
        return net, end_points
Пример #2
0
 def extract_features(self,
                      inputs,
                      use_batch_norm=False,
                      use_drop_out=False):
     """Build an EfficientNet backbone on the centred, std-scaled inputs.

     ``inputs`` is passed through ``self.center_inputs`` and then divided by
     ``eff.STDDEV_RGB``. The backbone is built inside the ``"efficientnet"``
     variable scope.

     Args:
         inputs: Image batch handed to ``self.center_inputs``.
             NOTE(review): presumably an NHWC tensor with 3 colour channels
             (the scaling constant has shape [1, 1, 3]) -- confirm.
         use_batch_norm: Forwarded to ``eff.build_model_base``.
         use_drop_out: Forwarded to ``eff.build_model_base`` as ``drop_out``.

     Returns:
         ``(net, end_points)`` as returned by ``eff.build_model_base``.
     """
     scaled = self.center_inputs(inputs)
     rgb_std = tf.constant(eff.STDDEV_RGB, shape=[1, 1, 3])
     scaled /= rgb_std
     with tf.compat.v1.variable_scope("efficientnet"):
         # The config name uses underscores; the builder is given the
         # hyphenated form (presumably its model-name convention -- confirm).
         model_name = self.cfg['net_type'].replace('_', '-')
         features, layer_outputs = eff.build_model_base(
             scaled, model_name,
             use_batch_norm=use_batch_norm, drop_out=use_drop_out)
     return features, layer_outputs