Ejemplo n.º 1
0
    def inference(self, net_input, num_classes, is_training):
        """Build the segmentation network selected by ``FLAGS.net``.

        Every supported model is constructed inside
        ``slim.arg_scope(aardvark.default_argscope(is_training))``.

        Args:
            net_input: Input tensor handed to the selected ``build_*`` function.
            num_classes: Number of output classes passed to the builder.
            is_training: Training flag forwarded to the arg scope, to
                ``fuck_slim.patch`` (when ``FLAGS.patch_slim`` is set) and to
                the builders that accept it.

        Returns:
            The network returned by the selected builder. The second value
            returned by the ResNet-based builders (``init_fn``) is stored on
            ``self.init_fn``; it is ``None`` for the other models.

        Raises:
            ValueError: if ``FLAGS.net`` names an unsupported model.
        """
        if FLAGS.patch_slim:
            fuck_slim.patch(is_training)
        model = FLAGS.net
        network = None
        init_fn = None
        # Every model is built under the same default arg scope, so enter it
        # once instead of repeating it in each branch.
        with slim.arg_scope(aardvark.default_argscope(is_training)):
            if model in ("FC-DenseNet56", "FC-DenseNet67", "FC-DenseNet103"):
                network = build_fc_densenet(net_input, preset_model=model,
                                            num_classes=num_classes)
            elif model in ("RefineNet-Res50", "RefineNet-Res101",
                           "RefineNet-Res152"):
                # RefineNet requires pre-trained ResNet weights.
                network, init_fn = build_refinenet(
                    net_input, preset_model=model, num_classes=num_classes,
                    is_training=is_training)
            elif model in ("FRRN-A", "FRRN-B"):
                network = build_frrn(net_input, preset_model=model,
                                     num_classes=num_classes)
            elif model in ("Encoder-Decoder", "Encoder-Decoder-Skip"):
                network = build_encoder_decoder(net_input, preset_model=model,
                                                num_classes=num_classes)
            elif model in ("MobileUNet", "MobileUNet-Skip"):
                network = build_mobile_unet(net_input, preset_model=model,
                                            num_classes=num_classes)
            elif model in ("PSPNet-Res50", "PSPNet-Res101", "PSPNet-Res152"):
                # PSPNet needs the image size and pre-trained ResNet weights.
                # NOTE(review): `args` is not defined in this method; it
                # presumably refers to a module-level argparse namespace —
                # confirm it exists in this file (FLAGS may have been meant).
                network, init_fn = build_pspnet(
                    net_input, label_size=[args.crop_height, args.crop_width],
                    preset_model=model, num_classes=num_classes,
                    is_training=is_training)
            elif model in ("GCN-Res50", "GCN-Res101", "GCN-Res152"):
                # GCN requires pre-trained ResNet weights.
                network, init_fn = build_gcn(
                    net_input, preset_model=model, num_classes=num_classes,
                    is_training=is_training)
            elif model in ("DeepLabV3-Res50", "DeepLabV3-Res101",
                           "DeepLabV3-Res152"):
                # DeepLabV3 requires pre-trained ResNet weights.
                network, init_fn = build_deeplabv3(
                    net_input, preset_model=model, num_classes=num_classes,
                    is_training=is_training)
            elif model in ("DeepLabV3_plus-Res50", "DeepLabV3_plus-Res101",
                           "DeepLabV3_plus-Res152"):
                # DeepLabV3+ requires pre-trained ResNet weights.
                network, init_fn = build_deeplabv3_plus(
                    net_input, preset_model=model, num_classes=num_classes,
                    is_training=is_training)
            elif model == "AdapNet":
                network = build_adaptnet(net_input, num_classes=num_classes)
            else:
                # The original message had an unfilled `%d` placeholder, so
                # the offending model name never appeared in the error.
                raise ValueError(
                    "Error: the model %s is not available. Try checking which "
                    "models are available using the command python main.py --help"
                    % model)

        self.init_fn = init_fn
        return network
Ejemplo n.º 2
0
 def rpn_parameters(self, channels, stride):
     """Upsample ``self.backbone`` with a single transposed convolution.

     The upscale factor is ``self.backbone_stride // stride``. The kernel size
     is twice that factor and the convolution stride equals it. No activation
     is applied. Layer defaults come from
     ``aardvark.default_argscope(self.is_training)``.

     Args:
         channels: Number of output channels.
         stride: Target output stride.

     Returns:
         The tensor produced by ``slim.conv2d_transpose``.
     """
     # NOTE(review): integer division; presumably `stride` divides
     # `self.backbone_stride` and is not larger than it — confirm at callers.
     upscale = self.backbone_stride // stride
     with slim.arg_scope(aardvark.default_argscope(self.is_training)):
         return slim.conv2d_transpose(self.backbone,
                                      channels,
                                      2 * upscale,
                                      upscale,
                                      activation_fn=None)
Ejemplo n.º 3
0
    def inference(self, images, classes, is_training):
        """Build per-pixel class logits on top of a stock slim backbone.

        The backbone is created by ``aardvark.create_stock_slim_network`` with
        output stride ``FLAGS.backbone_stride``. The head then does one of two
        things:

        * When ``FLAGS.multistep > 0``, it upsamples with
          ``slim_multistep_upscale`` and applies a 3x3 convolution.
        * Otherwise, it applies a single transposed convolution with stride
          ``FLAGS.backbone_stride``.

        When ``FLAGS.finetune`` is set:

        * gradients into the backbone are stopped;
        * ``self.init_session`` and ``self.variables_to_train`` are populated
          by ``aardvark.setup_finetune``;
        * only variables whose names do not start with ``FLAGS.backbone`` are
          trainable.

        Args:
            images: Input image tensor passed to the backbone.
            classes: Number of output channels of the logits.
            is_training: Training flag used for both the backbone and the
                head's arg scope.

        Returns:
            The logits tensor (no activation applied).

        Raises:
            ValueError: if ``FLAGS.clip_stride`` is not a multiple of
                ``FLAGS.backbone_stride``, or if fine-tuning is requested
                with ``FLAGS.colorspace`` other than ``'RGB'``.
        """
        # Validate configuration up front with real exceptions; `assert`
        # would be stripped under `python -O`.
        if FLAGS.clip_stride % FLAGS.backbone_stride != 0:
            raise ValueError(
                "FLAGS.clip_stride (%s) must be a multiple of "
                "FLAGS.backbone_stride (%s)"
                % (FLAGS.clip_stride, FLAGS.backbone_stride))
        if FLAGS.finetune and FLAGS.colorspace != 'RGB':
            raise ValueError(
                "fine-tuning requires FLAGS.colorspace == 'RGB', got %r"
                % (FLAGS.colorspace,))

        backbone = aardvark.create_stock_slim_network(
            FLAGS.backbone, images, is_training,
            global_pool=False, stride=FLAGS.backbone_stride)

        if FLAGS.finetune:
            # Block gradients into the pre-trained backbone.
            backbone = tf.stop_gradient(backbone)

        # Use the `is_training` argument, as the backbone does, rather than
        # `self.is_training`, so both halves of the network agree.
        with slim.arg_scope(aardvark.default_argscope(is_training)):
            if FLAGS.multistep > 0:
                if FLAGS.multistep == 1:
                    aardvark.print_red("multistep = 1 doesn't converge well")
                net = slim_multistep_upscale(backbone, FLAGS.backbone_stride,
                                             FLAGS.reduction, FLAGS.multistep)
                logits = slim.conv2d(net, classes, 3, 1, activation_fn=None,
                                     padding='SAME')
            else:
                # Upsample by FLAGS.backbone_stride in one transposed conv.
                logits = slim.conv2d_transpose(
                    backbone, classes, FLAGS.backbone_stride * 2,
                    FLAGS.backbone_stride, activation_fn=None, padding='SAME')

        if FLAGS.finetune:
            def is_trainable(x):
                # Variables under the backbone's name prefix stay frozen.
                return not x.startswith(FLAGS.backbone)

            self.init_session, self.variables_to_train = aardvark.setup_finetune(
                FLAGS.finetune, is_trainable)
        return logits
Ejemplo n.º 4
0
    def build_graph(self):
        """Build placeholders, backbone and the probability/offset head.

        Inputs:

        * Creates the ``is_training`` and ``images`` placeholders, plus the
          ``mask`` and ``gt_offsets`` placeholders that are only fed during
          training. All four are stored on ``self``.

        Head (under variable scope ``'head'``):

        * A sigmoid probability map with ``FLAGS.classes`` channels.
        * An offset map with ``FLAGS.classes * 2`` channels.
        * Both are produced by transposed convolutions with stride
          ``FLAGS.backbone_stride // FLAGS.stride``.

        Losses and outputs:

        * Two losses are registered with ``tf.losses.add_loss`` and appended
          to ``self.metrics``.
        * The first is the dice loss (tensor name ``'di'``).
        * The second is the mask-weighted offset loss, scaled by
          ``FLAGS.offset_weight`` (tensor name ``'p1'``).
        * The outputs are exposed as tensors named ``'prob'`` and
          ``'offsets'``.

        Fine-tuning: when ``FLAGS.finetune`` is set, gradients into the
        backbone are stopped and only variables whose names start with
        ``'head'`` are trainable.

        Raises:
            ValueError: if fine-tuning is requested with ``FLAGS.colorspace``
                other than ``'RGB'``.
        """
        # Fail before building any graph; `assert` would vanish under -O.
        if FLAGS.finetune and FLAGS.colorspace != 'RGB':
            raise ValueError(
                "fine-tuning requires FLAGS.colorspace == 'RGB', got %r"
                % (FLAGS.colorspace,))

        # Inputs.
        is_training = tf.placeholder(tf.bool, name="is_training")
        images = tf.placeholder(tf.float32,
                                shape=(None, None, None, FLAGS.channels),
                                name="images")
        # The rest are for training only.
        mask = tf.placeholder(tf.float32,
                              shape=(None, None, None, FLAGS.classes))
        gt_offsets = tf.placeholder(tf.float32,
                                    shape=(None, None, None,
                                           FLAGS.classes * 2))

        self.is_training = is_training
        self.images = images
        self.mask = mask
        self.gt_offsets = gt_offsets

        backbone = aardvark.create_stock_slim_network(
            FLAGS.backbone,
            images,
            is_training,
            global_pool=False,
            stride=FLAGS.backbone_stride)

        with tf.variable_scope('head'), slim.arg_scope(
                aardvark.default_argscope(is_training)):

            if FLAGS.finetune:
                backbone = tf.stop_gradient(backbone)
            # Upscale factor from the backbone's stride to the output stride.
            stride = FLAGS.backbone_stride // FLAGS.stride

            prob = slim.conv2d_transpose(backbone,
                                         FLAGS.classes,
                                         stride * 2,
                                         stride,
                                         activation_fn=tf.sigmoid)

            dice = tf.identity(dice_loss(mask, prob), name='di')
            tf.losses.add_loss(dice)
            self.metrics.append(dice)

            offsets = slim.conv2d_transpose(backbone,
                                            FLAGS.classes * 2,
                                            stride * 2,
                                            stride,
                                            activation_fn=None)
            # Flatten to rows of two offset values. The mask is flattened to
            # one weight per row. NOTE(review): this assumes the two offset
            # channels of each class are adjacent in both offsets and
            # gt_offsets — confirm against the data pipeline.
            offsets2 = tf.reshape(offsets, (-1, 2))
            gt_offsets2 = tf.reshape(gt_offsets, (-1, 2))
            mask2 = tf.reshape(mask, (-1, ))

            # Average the loss over masked entries; the +1 keeps the
            # denominator non-zero when the mask is empty.
            pl = params_loss(offsets2, gt_offsets2) * mask2
            pl = tf.reduce_sum(pl) / (tf.reduce_sum(mask2) + 1)
            # NOTE(review): tensor name is 'p1' although the message is 'pl';
            # kept as-is because callers may fetch it by name.
            pl = tf.check_numerics(pl * FLAGS.offset_weight, 'pl',
                                   name='p1')
            tf.losses.add_loss(pl)
            self.metrics.append(pl)

        # Named outputs for later lookup by tensor name.
        tf.identity(prob, name='prob')
        tf.identity(offsets, 'offsets')

        if FLAGS.finetune:
            def is_trainable(x):
                # Only head variables are trained during fine-tuning.
                return x.startswith('head')

            self.init_session, self.variables_to_train = aardvark.setup_finetune(
                FLAGS.finetune, is_trainable)