def exec_iter(spixel_feat, trans_features, spixel_init, num_spixels_h,
              num_spixels_w, num_spixels, trans_dim):
    """Add one unrolled SSN update iteration to the net being built.

    The step first computes pixel-superpixel assignments from the current
    superpixel features (``compute_assignments``). It then re-estimates
    superpixel features from those assignments with a ``SpixelFeature2``
    layer.

    Args:
        spixel_feat: Top holding the current superpixel features.
        trans_features: Top holding the transformed per-pixel features.
        spixel_init: Top holding the initial superpixel index map.
        num_spixels_h: Number of superpixels along the height of the grid.
        num_spixels_w: Number of superpixels along the width of the grid.
        num_spixels: Total number of superpixels (forwarded to
            ``compute_assignments``).
        trans_dim: Transformed feature dimensionality (forwarded to
            ``compute_assignments``).

    Returns:
        The Top produced by the ``SpixelFeature2`` layer, i.e. the updated
        superpixel features.
    """
    assoc = compute_assignments(spixel_feat, trans_features, spixel_init,
                                num_spixels_h, num_spixels_w, num_spixels,
                                trans_dim)
    grid_param = {'num_spixels_h': num_spixels_h,
                  'num_spixels_w': num_spixels_w}
    return L.SpixelFeature2(trans_features, assoc, spixel_init,
                            spixel_feature2_param=grid_param)


def create_ssn_net(img_height,
                   img_width,
                   num_spixels,
                   pos_scale,
                   color_scale,
                   num_spixels_h,
                   num_spixels_w,
                   num_steps,
                   phase=None):
    """Build the SSN network definition with ``caffe.NetSpec``.

    Args:
        img_height: Image height; used for the ``Input`` layer shapes when
            ``phase`` is neither ``'TRAIN'`` nor ``'TEST'``.
        img_width: Image width; used the same way as ``img_height``.
        num_spixels: Number of superpixels. It is appended to the Python data
            layer's ``param_str`` and used as ``max_spixels`` for the initial
            superpixel features.
        pos_scale: Position scale for the ``PixelFeature`` layer.
        color_scale: Color scale for the ``PixelFeature`` layer.
        num_spixels_h: Superpixel grid height.
        num_spixels_w: Superpixel grid width.
        num_steps: Total number of pixel-superpixel assignment computations.
            The first ``num_steps - 1`` are full update iterations
            (``exec_iter``, tops ``spixel_feat1`` ... ``spixel_feat{k}``).
            The last one produces ``final_pixel_assoc``. Any positive integer
            is accepted; 5 and 10 give the same layers as the former
            hand-unrolled versions.
        phase: ``'TRAIN'`` or ``'TEST'`` builds the Python data layer plus the
            losses. Any other value builds a deploy net fed by ``Input``
            layers.

    Returns:
        The ``NetParameter`` returned by ``NetSpec.to_proto()``.

    Raises:
        ValueError: If ``num_steps`` is not a positive integer. The check runs
            before any layer is created, so a bad value no longer surfaces as
            an opaque missing-top error.
    """
    steps = int(num_steps)
    if steps != num_steps or steps < 1:
        raise ValueError(
            'num_steps must be a positive integer, got %r' % (num_steps,))

    n = caffe.NetSpec()

    if phase == 'TRAIN':
        n.img, n.spixel_init, n.feat_spixel_init, n.label, n.problabel = \
            L.Python(python_param=dict(module="input_patch_data_layer",
                                       layer="InputRead",
                                       param_str="TRAIN_1000000_" + str(num_spixels)),
                     include=dict(phase=0),
                     ntop=5)
    elif phase == 'TEST':
        n.img, n.spixel_init, n.feat_spixel_init, n.label, n.problabel = \
            L.Python(python_param=dict(module="input_patch_data_layer",
                                       layer="InputRead",
                                       param_str="VAL_10_" + str(num_spixels)),
                     include=dict(phase=1),
                     ntop=5)
    else:
        # Deploy net: inputs are fed explicitly at the given image size.
        n.img = L.Input(shape=[dict(dim=[1, 3, img_height, img_width])])
        n.spixel_init = L.Input(
            shape=[dict(dim=[1, 1, img_height, img_width])])
        n.feat_spixel_init = L.Input(
            shape=[dict(dim=[1, 1, img_height, img_width])])

    n.pixel_features = L.PixelFeature(n.img,
                                      pixel_feature_param=dict(
                                          type=P.PixelFeature.POSITION_AND_RGB,
                                          pos_scale=float(pos_scale),
                                          color_scale=float(color_scale)))

    # Transform pixel features. trans_dim is not a parameter; it must be a
    # module-level name defined outside this view.
    n.trans_features = cnn_module(n.pixel_features, trans_dim)

    # Initial superpixel features.
    n.init_spixel_feat = L.SpixelFeature(n.trans_features, n.feat_spixel_init,
                                         spixel_feature_param=dict(
                                             type=P.SpixelFeature.AVGRGB,
                                             rgb_scale=1.0,
                                             ignore_idx_value=-10,
                                             ignore_feature_value=255,
                                             max_spixels=int(num_spixels)))

    # Unrolled iterations 1 .. steps-1. Each one is stored under the same top
    # name (spixel_featK) and in the same order as the former hand-written
    # unrolling, so the generated prototxt is unchanged for 5 and 10 steps.
    spixel_feat = n.init_spixel_feat
    for step in range(1, steps):
        spixel_feat = exec_iter(spixel_feat, n.trans_features, n.spixel_init,
                                num_spixels_h, num_spixels_w, num_spixels,
                                trans_dim)
        setattr(n, 'spixel_feat%d' % step, spixel_feat)

    # Final iteration: only the pixel-superpixel assignments are needed.
    n.final_pixel_assoc = \
        compute_assignments(spixel_feat, n.trans_features,
                            n.spixel_init, num_spixels_h,
                            num_spixels_w, num_spixels, trans_dim)

    if phase in ('TRAIN', 'TEST'):

        # Compute final spixel features.
        n.new_spixel_feat = L.SpixelFeature2(n.pixel_features,
                                             n.final_pixel_assoc,
                                             n.spixel_init,
                                             spixel_feature2_param=dict(
                                                 num_spixels_h=num_spixels_h,
                                                 num_spixels_w=num_spixels_w))

        n.new_spix_indices = compute_final_spixel_labels(
            n.final_pixel_assoc, n.spixel_init, num_spixels_h, num_spixels_w)
        # No gradient is propagated into the index map (second bottom).
        n.recon_feat2 = L.Smear(n.new_spixel_feat,
                                n.new_spix_indices,
                                propagate_down=[True, False])
        n.loss1, n.loss2 = position_color_loss(n.recon_feat2,
                                               n.pixel_features,
                                               pos_weight=0.00001,
                                               col_weight=0.0)

        # Convert pixel labels to spixel labels.
        n.spixel_label = L.SpixelFeature2(n.problabel,
                                          n.final_pixel_assoc,
                                          n.spixel_init,
                                          spixel_feature2_param=dict(
                                              num_spixels_h=num_spixels_h,
                                              num_spixels_w=num_spixels_w))
        # Convert spixel labels back to pixel labels.
        n.recon_label = decode_features(n.final_pixel_assoc,
                                        n.spixel_label,
                                        n.spixel_init,
                                        num_spixels_h,
                                        num_spixels_w,
                                        num_spixels,
                                        num_channels=50)

        n.recon_label = L.ReLU(n.recon_label, in_place=True)
        # Small shift before normalizing, presumably to avoid division by zero.
        n.recon_label2 = L.Power(n.recon_label, power_param=dict(shift=1e-10))
        n.recon_label3 = normalize(n.recon_label2, 50)
        n.loss3 = L.LossWithoutSoftmax(n.recon_label3,
                                       n.label,
                                       loss_param=dict(ignore_label=255),
                                       loss_weight=1.0)

    else:
        n.new_spix_indices = compute_final_spixel_labels(
            n.final_pixel_assoc, n.spixel_init, num_spixels_h, num_spixels_w)

    return n.to_proto()
# Example No. 3
# 0
def create_ssn_net(img_height, img_width,
                   num_spixels, pos_scale, color_scale,
                   num_spixels_h, num_spixels_w, num_steps,
                   phase=None):
    """Build an SSN network with superpixel pooling and EGB segmentation.

    Here ``cnn_module`` returns two tops, ``trans_features`` and ``conv_dsp``.

    In TRAIN/TEST phase, the network:
    - reads a ``seg_label`` top from the data layer;
    - pools ``conv_dsp`` over the final superpixels with
      ``SuperpixelPooling``;
    - adds a ``SimilarityLoss`` next to the compactness and label
      reconstruction losses.

    In deploy mode, ``conv_dsp`` and the final superpixel indices go into an
    ``EgbSegment`` layer.

    Args:
        img_height: Image height; used for the deploy ``Input`` layer shapes.
        img_width: Image width; used the same way as ``img_height``.
        num_spixels: Number of superpixels. It is appended to the Python data
            layer's ``param_str`` and used as ``max_spixels`` for the initial
            superpixel features.
        pos_scale: Position scale for the ``PixelFeature`` layer.
        color_scale: Color scale for the ``PixelFeature`` layer.
        num_spixels_h: Superpixel grid height.
        num_spixels_w: Superpixel grid width.
        num_steps: Total number of pixel-superpixel assignment computations.
            The first ``num_steps - 1`` are full update iterations
            (``exec_iter``, tops ``spixel_feat1`` ... ``spixel_feat{k}``).
            The last one produces ``final_pixel_assoc``. Any positive integer
            is accepted; 5 and 10 give the same layers as the former
            hand-unrolled versions.
        phase: ``'TRAIN'`` or ``'TEST'`` builds the Python data layer plus the
            losses. Any other value builds a deploy net fed by ``Input``
            layers.

    Returns:
        The ``NetParameter`` returned by ``NetSpec.to_proto()``.

    Raises:
        ValueError: If ``num_steps`` is not a positive integer. The check runs
            before any layer is created.
    """
    steps = int(num_steps)
    if steps != num_steps or steps < 1:
        raise ValueError(
            'num_steps must be a positive integer, got %r' % (num_steps,))

    n = caffe.NetSpec()

    if phase == 'TRAIN':
        n.img, n.spixel_init, n.feat_spixel_init, n.label, n.problabel, n.seg_label = \
            L.Python(python_param=dict(module="input_patch_data_layer",
                                       layer="InputRead",
                                       param_str="TRAIN_1000000_" + str(num_spixels)),
                     include=dict(phase=0),
                     ntop=6)

    elif phase == 'TEST':
        n.img, n.spixel_init, n.feat_spixel_init, n.label, n.problabel, n.seg_label = \
            L.Python(python_param=dict(module="input_patch_data_layer",
                                       layer="InputRead",
                                       param_str="VAL_10_" + str(num_spixels)),
                     include=dict(phase=1),
                     ntop=6)
    else:
        # Deploy net: inputs are fed explicitly. bound_param and minsize_param
        # are 1x1x1x1 blobs consumed by the EgbSegment layer below.
        n.img = L.Input(shape=[dict(dim=[1, 3, img_height, img_width])])
        n.spixel_init = L.Input(shape=[dict(dim=[1, 1, img_height, img_width])])
        n.feat_spixel_init = L.Input(shape=[dict(dim=[1, 1, img_height, img_width])])
        n.bound_param = L.Input(shape=[dict(dim=[1, 1, 1, 1])])
        n.minsize_param = L.Input(shape=[dict(dim=[1, 1, 1, 1])])

    # NOTE(review): the original author was unsure how pixel_features is
    # derived. The layer implementations live under
    # lib/video_prop_networks/lib/caffe/src/caffe/layers.
    n.pixel_features = L.PixelFeature(n.img,
                                      pixel_feature_param=dict(
                                          type=P.PixelFeature.POSITION_AND_RGB,
                                          pos_scale=float(pos_scale),
                                          color_scale=float(color_scale)))

    # Transform pixel features. trans_dim is a module-level name defined
    # outside this view (the original comment says it is 15).
    n.trans_features, n.conv_dsp = cnn_module(n.pixel_features, trans_dim)

    # Initial superpixel features.
    n.init_spixel_feat = L.SpixelFeature(n.trans_features, n.feat_spixel_init,
                                         spixel_feature_param=dict(
                                             type=P.SpixelFeature.AVGRGB,
                                             rgb_scale=1.0,
                                             ignore_idx_value=-10,
                                             ignore_feature_value=255,
                                             max_spixels=int(num_spixels)))

    # Unrolled iterations 1 .. steps-1. Each one is stored under the same top
    # name (spixel_featK) and in the same order as the former hand-written
    # unrolling, so the generated prototxt is unchanged for 5 and 10 steps.
    spixel_feat = n.init_spixel_feat
    for step in range(1, steps):
        spixel_feat = exec_iter(spixel_feat, n.trans_features, n.spixel_init,
                                num_spixels_h, num_spixels_w, num_spixels,
                                trans_dim)
        setattr(n, 'spixel_feat%d' % step, spixel_feat)

    # Final iteration: the soft association between pixels and superpixels.
    n.final_pixel_assoc = \
        compute_assignments(spixel_feat, n.trans_features,
                            n.spixel_init, num_spixels_h,
                            num_spixels_w, num_spixels, trans_dim)

    if phase in ('TRAIN', 'TEST'):

        # Compute final spixel features (input to the compactness loss).
        n.new_spixel_feat = L.SpixelFeature2(n.pixel_features,
                                             n.final_pixel_assoc,
                                             n.spixel_init,
                                             spixel_feature2_param=dict(
                                                 num_spixels_h=num_spixels_h,
                                                 num_spixels_w=num_spixels_w))

        # Final superpixel labels, derived from the final pixel-superpixel
        # association.
        n.new_spix_indices = compute_final_spixel_labels(n.final_pixel_assoc,
                                                         n.spixel_init,
                                                         num_spixels_h, num_spixels_w)

        # Superpixel pooling of conv_dsp and seg_label over the final
        # superpixels.
        n.superpixel_pooling_out, n.superpixel_seg_label = L.SuperpixelPooling(
            n.conv_dsp, n.seg_label, n.new_spix_indices,
            superpixel_pooling_param=dict(pool_type=P.Pooling.AVE),
            ntop=2)

        n.loss0 = L.SimilarityLoss(n.superpixel_pooling_out, n.superpixel_seg_label,
                                   n.new_spix_indices,
                                   loss_weight=0.1,
                                   similarity_loss_param=dict(sample_points=1))

        # No gradient is propagated into the index map (second bottom).
        n.recon_feat2 = L.Smear(n.new_spixel_feat, n.new_spix_indices,
                                propagate_down=[True, False])
        n.loss1, n.loss2 = position_color_loss(n.recon_feat2, n.pixel_features,
                                               pos_weight=0.00001,
                                               col_weight=0.0)

        # Task-specific reconstruction loss: convert pixel labels to spixel
        # labels.
        # NOTE(review): the original note said this "should be a hard
        # association, used for the loss", yet the soft final_pixel_assoc is
        # passed here. An unfinished note also compared spixel_label with the
        # features above. Confirm the intent.
        n.spixel_label = L.SpixelFeature2(n.problabel,
                                          n.final_pixel_assoc,
                                          n.spixel_init,
                                          spixel_feature2_param=dict(
                                              num_spixels_h=num_spixels_h,
                                              num_spixels_w=num_spixels_w))
        # Convert spixel labels back to pixel labels.
        n.recon_label = decode_features(n.final_pixel_assoc, n.spixel_label, n.spixel_init,
                                        num_spixels_h, num_spixels_w, num_spixels,
                                        num_channels=50)

        n.recon_label = L.ReLU(n.recon_label, in_place=True)
        # Small shift before normalizing, presumably to avoid division by zero.
        n.recon_label2 = L.Power(n.recon_label, power_param=dict(shift=1e-10))
        n.recon_label3 = normalize(n.recon_label2, 50)
        n.loss3 = L.LossWithoutSoftmax(n.recon_label3, n.label,
                                       loss_param=dict(ignore_label=255),
                                       loss_weight=10)

    else:
        n.new_spix_indices = compute_final_spixel_labels(n.final_pixel_assoc,
                                                         n.spixel_init,
                                                         num_spixels_h, num_spixels_w)
        # NOTE(review): bound/min_size appear both as input blobs and as layer
        # params here; which one takes precedence depends on the EgbSegment
        # implementation. Confirm.
        n.segmentation = L.EgbSegment(n.conv_dsp, n.new_spix_indices,
                                      n.bound_param, n.minsize_param,
                                      egb_segment_param=dict(bound=3, min_size=10))

    # NetSpec holds the Tops assigned to it as attributes. to_proto() emits a
    # NetParameter with every layer needed to produce them, named after those
    # attribute names.
    return n.to_proto()