def inference(self, net_input, num_classes, is_training):
    """Build the segmentation network selected by ``FLAGS.net``.

    Args:
        net_input: input tensor handed to the selected ``build_*`` function.
        num_classes: number of output classes passed to the builder.
        is_training: training flag; forwarded to the default slim arg scope
            and to the builders that accept it.

    Returns:
        The network output produced by the selected builder.

    Side effects:
        Sets ``self.init_fn`` to the initialisation callback returned by the
        builders that need pre-trained ResNet weights (RefineNet, PSPNet, GCN,
        DeepLabV3, DeepLabV3+), or to ``None`` for every other model.

    Raises:
        ValueError: if ``FLAGS.net`` is not one of the supported model names.
    """
    if FLAGS.patch_slim:
        fuck_slim.patch(is_training)

    net_name = FLAGS.net

    def resnet_presets(family):
        # e.g. "GCN" -> ("GCN-Res50", "GCN-Res101", "GCN-Res152")
        return tuple("%s-Res%d" % (family, depth) for depth in (50, 101, 152))

    network = None
    init_fn = None
    # Every builder runs under the same default arg scope, so enter it once
    # instead of repeating it in each branch.
    with slim.arg_scope(aardvark.default_argscope(is_training)):
        if net_name in ("FC-DenseNet56", "FC-DenseNet67", "FC-DenseNet103"):
            network = build_fc_densenet(net_input, preset_model=net_name,
                                        num_classes=num_classes)
        elif net_name in resnet_presets("RefineNet"):
            # RefineNet requires pre-trained ResNet weights
            network, init_fn = build_refinenet(net_input, preset_model=net_name,
                                               num_classes=num_classes,
                                               is_training=is_training)
        elif net_name in ("FRRN-A", "FRRN-B"):
            network = build_frrn(net_input, preset_model=net_name,
                                 num_classes=num_classes)
        elif net_name in ("Encoder-Decoder", "Encoder-Decoder-Skip"):
            network = build_encoder_decoder(net_input, preset_model=net_name,
                                            num_classes=num_classes)
        elif net_name in ("MobileUNet", "MobileUNet-Skip"):
            network = build_mobile_unet(net_input, preset_model=net_name,
                                        num_classes=num_classes)
        elif net_name in resnet_presets("PSPNet"):
            # Image size is required for PSPNet
            # PSPNet requires pre-trained ResNet weights
            # NOTE(review): `args` is not defined in this method; this only
            # works if a module-level `args` with crop_height/crop_width
            # exists. Confirm, or switch to the equivalent FLAGS values.
            network, init_fn = build_pspnet(net_input,
                                            label_size=[args.crop_height, args.crop_width],
                                            preset_model=net_name,
                                            num_classes=num_classes,
                                            is_training=is_training)
        elif net_name in resnet_presets("GCN"):
            # GCN requires pre-trained ResNet weights
            network, init_fn = build_gcn(net_input, preset_model=net_name,
                                         num_classes=num_classes,
                                         is_training=is_training)
        elif net_name in resnet_presets("DeepLabV3"):
            # DeepLabV3 requires pre-trained ResNet weights
            network, init_fn = build_deeplabv3(net_input, preset_model=net_name,
                                               num_classes=num_classes,
                                               is_training=is_training)
        elif net_name in resnet_presets("DeepLabV3_plus"):
            # DeepLabV3+ requires pre-trained ResNet weights
            network, init_fn = build_deeplabv3_plus(net_input, preset_model=net_name,
                                                    num_classes=num_classes,
                                                    is_training=is_training)
        elif net_name == "AdapNet":
            network = build_adaptnet(net_input, num_classes=num_classes)
        else:
            # The original message used "%d" without interpolation, so the
            # unsupported model name never appeared in the error.
            raise ValueError(
                "Error: the model %s is not available. Try checking which models "
                "are available using the command python main.py --help" % net_name)
    self.init_fn = init_fn
    return network
def rpn_parameters(self, channels, stride):
    """Upsample ``self.backbone`` to output stride ``stride`` and predict ``channels`` maps.

    Args:
        channels: number of output channels of the transposed convolution.
        stride: target output stride. The backbone output (at
            ``self.backbone_stride``) is upsampled by
            ``self.backbone_stride // stride``.

    Returns:
        The output of one linear (``activation_fn=None``)
        ``slim.conv2d_transpose`` applied to ``self.backbone`` under the
        default arg scope built from ``self.is_training``.
    """
    # Integer division truncates. NOTE(review): presumably callers always
    # pass a stride that divides self.backbone_stride; confirm, otherwise the
    # upsampling factor is silently rounded down.
    upscale = self.backbone_stride // stride
    with slim.arg_scope(aardvark.default_argscope(self.is_training)):
        # Kernel size 2 * upscale, stride upscale.
        return slim.conv2d_transpose(self.backbone, channels, 2 * upscale, upscale,
                                     activation_fn=None)
def inference(self, images, classes, is_training):
    """Build a stock slim backbone plus an upsampling head that predicts logits.

    Args:
        images: input image tensor for the backbone.
        classes: number of output channels (classes) of the logits.
        is_training: training flag used for both the backbone and the head's
            default arg scope.

    Returns:
        Linear logits with ``classes`` channels, produced either by
        ``slim_multistep_upscale`` followed by a 3x3 conv (when
        ``FLAGS.multistep > 0``), or by a single transposed conv with kernel
        ``2 * FLAGS.backbone_stride`` and stride ``FLAGS.backbone_stride``.

    Side effects:
        When ``FLAGS.finetune`` is set, assigns ``self.init_session`` and
        ``self.variables_to_train`` from ``aardvark.setup_finetune``.
    """
    assert FLAGS.clip_stride % FLAGS.backbone_stride == 0, \
        "FLAGS.clip_stride must be a multiple of FLAGS.backbone_stride"
    backbone = aardvark.create_stock_slim_network(FLAGS.backbone, images, is_training,
                                                  global_pool=False,
                                                  stride=FLAGS.backbone_stride)
    if FLAGS.finetune:
        # Freeze the backbone: no gradient flows back into it.
        backbone = tf.stop_gradient(backbone)

    # Use the same is_training the backbone was built with. The original read
    # self.is_training here, which only works if that attribute was set
    # elsewhere and kept in sync with this argument.
    with slim.arg_scope(aardvark.default_argscope(is_training)):
        if FLAGS.multistep > 0:
            if FLAGS.multistep == 1:
                aardvark.print_red("multistep = 1 doesn't converge well")
            net = slim_multistep_upscale(backbone, FLAGS.backbone_stride,
                                         FLAGS.reduction, FLAGS.multistep)
            logits = slim.conv2d(net, classes, 3, 1, activation_fn=None, padding='SAME')
        else:
            # Single-step upsampling straight back to input resolution.
            logits = slim.conv2d_transpose(backbone, classes, FLAGS.backbone_stride * 2,
                                           FLAGS.backbone_stride,
                                           activation_fn=None, padding='SAME')

    if FLAGS.finetune:
        assert FLAGS.colorspace == 'RGB', "finetuning requires FLAGS.colorspace == 'RGB'"

        def is_trainable(x):
            # Train only variables outside the backbone's name prefix.
            return not x.startswith(FLAGS.backbone)

        self.init_session, self.variables_to_train = aardvark.setup_finetune(
            FLAGS.finetune, is_trainable)
    return logits
def build_graph(self):
    """Build the model graph, its training losses and the optional finetune setup.

    Creates the input placeholders and stores them on ``self``
    (``is_training``, ``images``, ``mask``, ``gt_offsets``). It then builds a
    stock slim backbone followed by a ``head`` variable scope that upsamples
    from ``FLAGS.backbone_stride`` to ``FLAGS.stride`` with transposed
    convolutions to predict:

    * ``prob``: ``FLAGS.classes`` sigmoid maps, and
    * ``offsets``: ``FLAGS.classes * 2`` linear maps.

    Two losses are registered with ``tf.losses`` and appended to
    ``self.metrics``:

    * a dice loss between ``mask`` and ``prob`` (identity named ``di``), and
    * a mask-weighted offset loss scaled by ``FLAGS.offset_weight``
      (named ``p1``).
    """
    # Inputs / parameters.
    is_training = tf.placeholder(tf.bool, name="is_training")
    images = tf.placeholder(tf.float32, shape=(None, None, None, FLAGS.channels),
                            name="images")
    # The rest are for training only.
    mask = tf.placeholder(tf.float32, shape=(None, None, None, FLAGS.classes))
    gt_offsets = tf.placeholder(tf.float32, shape=(None, None, None, FLAGS.classes * 2))

    self.is_training = is_training
    self.images = images
    self.mask = mask
    self.gt_offsets = gt_offsets

    backbone = aardvark.create_stock_slim_network(
        FLAGS.backbone, images, is_training, global_pool=False,
        stride=FLAGS.backbone_stride)

    # The two transposed convolutions below have no explicit scope, so slim
    # numbers their variables by creation order. Keep the order unchanged to
    # stay compatible with existing checkpoints.
    with tf.variable_scope('head'), slim.arg_scope(
            aardvark.default_argscope(is_training)):
        if FLAGS.finetune:
            # Freeze the backbone: only the head receives gradients.
            backbone = tf.stop_gradient(backbone)
        # Upsampling factor from the backbone output stride to the target stride.
        stride = FLAGS.backbone_stride // FLAGS.stride

        prob = slim.conv2d_transpose(backbone, FLAGS.classes, stride * 2, stride,
                                     activation_fn=tf.sigmoid)

        dice = tf.identity(dice_loss(mask, prob), name='di')
        tf.losses.add_loss(dice)
        self.metrics.append(dice)

        offsets = slim.conv2d_transpose(backbone, FLAGS.classes * 2, stride * 2, stride,
                                        activation_fn=None)
        # Flatten into rows of 2 consecutive channels, so each row lines up
        # with one mask element. Presumably each row is one class's (dx, dy)
        # offset; NOTE(review): confirm the channel layout.
        offsets2 = tf.reshape(offsets, (-1, 2))
        gt_offsets2 = tf.reshape(gt_offsets, (-1, 2))
        mask2 = tf.reshape(mask, (-1, ))
        # Only positions covered by the mask contribute to the offset loss.
        pl = params_loss(offsets2, gt_offsets2) * mask2
        # Normalise by the mask total; the +1 keeps the denominator non-zero
        # when the mask is empty.
        pl = tf.reduce_sum(pl) / (tf.reduce_sum(mask2) + 1)
        pl = tf.check_numerics(pl * FLAGS.offset_weight, 'pl', name='p1')  # params-loss
        tf.losses.add_loss(pl)
        self.metrics.append(pl)

    # Named outputs, created outside the 'head' scope.
    tf.identity(prob, name='prob')
    tf.identity(offsets, 'offsets')

    if FLAGS.finetune:
        assert FLAGS.colorspace == 'RGB'

        def is_trainable(x):
            # Only variables under the 'head' scope are trained.
            return x.startswith('head')

        self.init_session, self.variables_to_train = aardvark.setup_finetune(
            FLAGS.finetune, is_trainable)