Exemplo n.º 1
0
def centroid_pos_color_loss(trans_features, computed_spixel_feat,
                            new_spix_indices, num_spixels, l_weight_pos,
                            l_weight_color):
    """Build two Euclidean losses between computed and re-pooled superpixel features.

    ``trans_features`` are averaged per superpixel (AVGRGB SpixelFeature layer)
    using ``new_spix_indices``. Both the pooled result and
    ``computed_spixel_feat`` are then split along axis 1 at channel 2. The first
    two channels and the remaining channels are compared by separate
    EuclideanLoss layers.

    Args:
        trans_features: per-pixel feature blob to pool.
        computed_spixel_feat: superpixel feature blob compared against the
            pooled features.
        new_spix_indices: superpixel index map; no gradient flows into it
            (``propagate_down=[True, False]``).
        num_spixels: upper bound on the number of superpixels (cast to int).
        l_weight_pos: loss weight for channels [0, 2) (presumably the position
            channels -- TODO confirm against the feature layout).
        l_weight_color: loss weight for channels [2, C).

    Returns:
        Tuple ``(pos_loss, color_loss)`` of EuclideanLoss tops.
    """
    pool_param = dict(type=P.SpixelFeature.AVGRGB,
                      rgb_scale=1.0,
                      ignore_idx_value=-10,
                      ignore_feature_value=255,
                      max_spixels=int(num_spixels))
    pooled = L.SpixelFeature(trans_features, new_spix_indices,
                             spixel_feature_param=pool_param,
                             propagate_down=[True, False])

    recon_pos, recon_color = L.Slice(computed_spixel_feat,
                                     slice_param=dict(axis=1, slice_point=2),
                                     ntop=2)
    target_pos, target_color = L.Slice(pooled,
                                       slice_param=dict(axis=1, slice_point=2),
                                       ntop=2)

    return (L.EuclideanLoss(recon_pos, target_pos, loss_weight=l_weight_pos),
            L.EuclideanLoss(recon_color, target_color,
                            loss_weight=l_weight_color))
Exemplo n.º 2
0
def centroid_loss(trans_features, computed_spixel_feat, new_spix_indices,
                  num_spixels, l_weight):
    """Return a EuclideanLoss between computed and re-pooled superpixel features.

    ``trans_features`` are averaged per superpixel (AVGRGB SpixelFeature layer)
    over ``new_spix_indices``. Gradients are propagated only into the features,
    not into the index map. The pooled blob is compared against
    ``computed_spixel_feat`` with loss weight ``l_weight``.
    """
    pooled = L.SpixelFeature(
        trans_features,
        new_spix_indices,
        spixel_feature_param=dict(type=P.SpixelFeature.AVGRGB,
                                  rgb_scale=1.0,
                                  ignore_idx_value=-10,
                                  ignore_feature_value=255,
                                  max_spixels=int(num_spixels)),
        propagate_down=[True, False])
    return L.EuclideanLoss(computed_spixel_feat, pooled, loss_weight=l_weight)
def load_spixel_feature_model():
    """Build a standalone SpixelFeature net and load it as a TEST-phase caffe.Net.

    The net has two Input blobs:
    - ``img_features``: shape (1, 6, 480, 854). The channel meaning is presumably
      position + RGB -- TODO confirm.
    - ``spixel_indices``: shape (1, 1, 480, 854).

    It averages the features per superpixel (AVGRGB). The layer uses the
    module-level ``max_spixels`` and ``ignore_feat_value``, which are not
    defined in this function.

    Returns:
        caffe.Net: the instantiated network in TEST phase.
    """
    import os

    n = caffe.NetSpec()

    n.img_features = L.Input(shape=[dict(dim=[1, 6, 480, 854])])
    n.spixel_indices = L.Input(shape=[dict(dim=[1, 1, 480, 854])])
    n.spixel_features = L.SpixelFeature(
        n.img_features,
        n.spixel_indices,
        spixel_feature_param=dict(type=P.SpixelFeature.AVGRGB,
                                  max_spixels=max_spixels,
                                  rgb_scale=1.0,
                                  ignore_feature_value=ignore_feat_value))

    # caffe.Net takes a prototxt path, so serialise the spec to a temp file.
    # delete=False lets Caffe reopen the file by name after it is closed.
    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
        f.write(str(n.to_proto()))
    try:
        # The prototxt is parsed while the Net is constructed, so the file is
        # no longer needed afterwards. Remove it instead of leaking it.
        return caffe.Net(f.name, caffe.TEST)
    finally:
        os.remove(f.name)
def create_ssn_net(img_height,
                   img_width,
                   num_spixels,
                   pos_scale,
                   color_scale,
                   num_spixels_h,
                   num_spixels_w,
                   num_steps,
                   phase=None):
    """Build the SSN network definition and return it as a NetParameter.

    Relies on helpers and settings defined elsewhere in the module:
    ``trans_dim``, ``cnn_module``, ``exec_iter``, ``compute_assignments``,
    ``compute_final_spixel_labels``, ``position_color_loss``,
    ``decode_features`` and ``normalize``.

    Args:
        img_height, img_width: spatial size of the deploy-mode Input layers.
        num_spixels: superpixel count. It is appended to the data-layer
            ``param_str`` and used as ``max_spixels`` for the initial pooling.
        pos_scale, color_scale: scales passed to the PixelFeature layer.
        num_spixels_h, num_spixels_w: superpixel grid dimensions.
        num_steps: total number of iterations (a positive integer). The first
            ``num_steps - 1`` call ``exec_iter`` and register tops
            ``spixel_feat1 .. spixel_feat{num_steps-1}``. The last calls
            ``compute_assignments`` and yields ``final_pixel_assoc``.
        phase: ``'TRAIN'`` or ``'TEST'`` adds the Python data layer and the
            loss layers. Any other value (default ``None``) builds a deploy
            net fed by Input layers.

    Returns:
        The NetParameter produced by ``NetSpec.to_proto()``.

    Raises:
        ValueError: if ``num_steps`` is not a positive integer.
    """
    steps = int(num_steps)
    if steps != num_steps or steps < 1:
        raise ValueError('num_steps must be a positive integer, got %r'
                         % (num_steps,))

    n = caffe.NetSpec()

    if phase in ('TRAIN', 'TEST'):
        # Python data layer. include.phase 0/1 are Caffe's TRAIN/TEST phases.
        split, phase_id = (('TRAIN_1000000_', 0) if phase == 'TRAIN'
                           else ('VAL_10_', 1))
        n.img, n.spixel_init, n.feat_spixel_init, n.label, n.problabel = \
            L.Python(python_param=dict(module="input_patch_data_layer",
                                       layer="InputRead",
                                       param_str=split + str(num_spixels)),
                     include=dict(phase=phase_id),
                     ntop=5)
    else:
        n.img = L.Input(shape=[dict(dim=[1, 3, img_height, img_width])])
        n.spixel_init = L.Input(
            shape=[dict(dim=[1, 1, img_height, img_width])])
        n.feat_spixel_init = L.Input(
            shape=[dict(dim=[1, 1, img_height, img_width])])

    n.pixel_features = L.PixelFeature(n.img,
                                      pixel_feature_param=dict(
                                          type=P.PixelFeature.POSITION_AND_RGB,
                                          pos_scale=float(pos_scale),
                                          color_scale=float(color_scale)))

    ### Transform pixel features
    n.trans_features = cnn_module(n.pixel_features, trans_dim)

    # Initial superpixel features: average of trans_features over the
    # initial superpixels given by feat_spixel_init.
    n.init_spixel_feat = L.SpixelFeature(n.trans_features, n.feat_spixel_init,
                                         spixel_feature_param=dict(
                                             type=P.SpixelFeature.AVGRGB,
                                             rgb_scale=1.0,
                                             ignore_idx_value=-10,
                                             ignore_feature_value=255,
                                             max_spixels=int(num_spixels)))

    ### Iterations 1 .. steps-1: refine the superpixel features.
    spixel_feat = n.init_spixel_feat
    for step in range(1, steps):
        spixel_feat = exec_iter(spixel_feat, n.trans_features, n.spixel_init,
                                num_spixels_h, num_spixels_w, num_spixels,
                                trans_dim)
        # Registered as spixel_feat1, spixel_feat2, ... in creation order.
        setattr(n, 'spixel_feat%d' % step, spixel_feat)

    ### Final iteration: only the pixel-superpixel associations are needed.
    n.final_pixel_assoc = \
        compute_assignments(spixel_feat, n.trans_features,
                            n.spixel_init, num_spixels_h,
                            num_spixels_w, num_spixels, trans_dim)

    if phase in ('TRAIN', 'TEST'):

        # Compute final spixel features from the pixel features and the
        # final associations.
        n.new_spixel_feat = L.SpixelFeature2(n.pixel_features,
                                             n.final_pixel_assoc,
                                             n.spixel_init,
                                             spixel_feature2_param=dict(
                                                 num_spixels_h=num_spixels_h,
                                                 num_spixels_w=num_spixels_w))

        n.new_spix_indices = compute_final_spixel_labels(
            n.final_pixel_assoc, n.spixel_init, num_spixels_h, num_spixels_w)
        # Smear superpixel features back to pixels. No gradient flows into the
        # index map.
        n.recon_feat2 = L.Smear(n.new_spixel_feat,
                                n.new_spix_indices,
                                propagate_down=[True, False])
        # Position term weighted 1e-5; the color term is disabled (weight 0).
        n.loss1, n.loss2 = position_color_loss(n.recon_feat2,
                                               n.pixel_features,
                                               pos_weight=0.00001,
                                               col_weight=0.0)

        # Convert pixel labels to spixel labels
        n.spixel_label = L.SpixelFeature2(n.problabel,
                                          n.final_pixel_assoc,
                                          n.spixel_init,
                                          spixel_feature2_param=dict(
                                              num_spixels_h=num_spixels_h,
                                              num_spixels_w=num_spixels_w))
        # Convert spixel labels back to pixel labels (50 channels)
        n.recon_label = decode_features(n.final_pixel_assoc,
                                        n.spixel_label,
                                        n.spixel_init,
                                        num_spixels_h,
                                        num_spixels_w,
                                        num_spixels,
                                        num_channels=50)

        n.recon_label = L.ReLU(n.recon_label, in_place=True)
        # Add 1e-10 so values are strictly positive before normalisation.
        n.recon_label2 = L.Power(n.recon_label, power_param=dict(shift=1e-10))
        n.recon_label3 = normalize(n.recon_label2, 50)
        n.loss3 = L.LossWithoutSoftmax(n.recon_label3,
                                       n.label,
                                       loss_param=dict(ignore_label=255),
                                       loss_weight=1.0)

    else:
        n.new_spix_indices = compute_final_spixel_labels(
            n.final_pixel_assoc, n.spixel_init, num_spixels_h, num_spixels_w)

    return n.to_proto()
Exemplo n.º 5
0
def _bnn_lr_params():
    """Return a fresh [weights, bias] lr/decay multiplier list for one layer."""
    return [{'lr_mult': 1, 'decay_mult': 1}, {'lr_mult': 2, 'decay_mult': 0}]


def _bnn_permutohedral(data, in_features, out_features):
    """Return the 32-output Permutohedral layer used by every BNN stage."""
    return L.Permutohedral(data,
                           in_features,
                           out_features,
                           permutohedral_param=dict(
                               num_output=32,
                               group=1,
                               neighborhood_size=0,
                               bias_term=True,
                               norm_type=P.Permutohedral.AFTER,
                               offset_type=P.Permutohedral.NONE),
                           filter_filler=dict(type='gaussian', std=0.01),
                           bias_filler=dict(type='constant', value=0),
                           param=_bnn_lr_params())


def _bnn_conv(bottom, num_output, kernel_size, pad=None, bias_value=0):
    """Return a stride-1 Convolution with gaussian(std=0.01) weights.

    The bias is initialised to the constant ``bias_value``. When ``pad`` is
    given it is used for both ``pad_h`` and ``pad_w``; otherwise no padding
    fields are set.
    """
    conv_param = dict(num_output=num_output, kernel_size=kernel_size, stride=1)
    if pad is not None:
        conv_param['pad_h'] = pad
        conv_param['pad_w'] = pad
    conv_param['weight_filler'] = dict(type='gaussian', std=0.01)
    conv_param['bias_filler'] = dict(type='constant', value=bias_value)
    return L.Convolution(bottom, convolution_param=conv_param,
                         param=_bnn_lr_params())


def create_bnn_cnn_net_fold_stage(num_input_frames,
                                  fold_id='0',
                                  stage_id='1',
                                  phase=None):
    """Build the BNN + DeepLab + CNN segmentation net for one fold/stage.

    Relies on ``deeplab``, ``normalize`` and ``max_spixels``, which are defined
    elsewhere in the module.

    Args:
        num_input_frames: number of frames stacked in the unary/feature
            inputs (deploy mode).
        fold_id, stage_id: identifiers appended to the data-layer
            ``param_str`` as ``"<fold_id>_<stage_id>"``. Non-string ids are
            converted with ``str()``.
        phase: ``'TRAIN'`` or ``'TEST'`` adds the Python data layer plus
            loss/accuracy layers. Any other value builds a deploy net with
            Input layers and a per-superpixel output.

    Returns:
        The NetParameter produced by ``NetSpec.to_proto()``.
    """
    n = caffe.NetSpec()
    fold_stage = str(fold_id) + '_' + str(stage_id)

    if phase in ('TRAIN', 'TEST'):
        # Python data layer. include.phase 0/1 are Caffe's TRAIN/TEST phases.
        split, phase_id = (('TRAIN_1000000_', 0) if phase == 'TRAIN'
                           else ('VAL_50_', 1))
        (n.img, n.padimg, n.unary, n.in_features, n.out_features,
         n.spixel_indices, n.scales1, n.scales2, n.unary_scales,
         n.label) = L.Python(python_param=dict(module="input_data_layer",
                                               layer="InputRead",
                                               param_str=split + fold_stage),
                             include=dict(phase=phase_id),
                             ntop=10)
    else:
        n.img = L.Input(shape=[dict(dim=[1, 3, 480, 854])])
        n.padimg = L.Input(shape=[dict(dim=[1, 3, 481, 857])])

        n.unary = L.Input(
            shape=[dict(dim=[1, 2, num_input_frames, max_spixels])])
        n.in_features = L.Input(
            shape=[dict(dim=[1, 6, num_input_frames, max_spixels])])
        n.out_features = L.Input(shape=[dict(dim=[1, 6, 1, max_spixels])])
        n.spixel_indices = L.Input(shape=[dict(dim=[1, 1, 480, 854])])
        n.scales1 = L.Input(shape=[dict(dim=[1, 6, 1, 1])])
        n.scales2 = L.Input(shape=[dict(dim=[1, 6, 1, 1])])
        n.unary_scales = L.Input(shape=[dict(dim=[1, 1, num_input_frames, 1])])

    # Flatten the per-channel scale blobs so Scale can apply them along axis 1/2.
    n.flatten_scales1 = L.Flatten(n.scales1, flatten_param=dict(axis=0))
    n.flatten_scales2 = L.Flatten(n.scales2, flatten_param=dict(axis=0))
    n.flatten_unary_scales = L.Flatten(n.unary_scales,
                                       flatten_param=dict(axis=0))

    n.in_scaled_features1 = L.Scale(n.in_features,
                                    n.flatten_scales1,
                                    scale_param=dict(axis=1))
    n.out_scaled_features1 = L.Scale(n.out_features,
                                     n.flatten_scales1,
                                     scale_param=dict(axis=1))

    n.in_scaled_features2 = L.Scale(n.in_features,
                                    n.flatten_scales2,
                                    scale_param=dict(axis=1))
    n.out_scaled_features2 = L.Scale(n.out_features,
                                     n.flatten_scales2,
                                     scale_param=dict(axis=1))
    n.scaled_unary = L.Scale(n.unary,
                             n.flatten_unary_scales,
                             scale_param=dict(axis=2))

    ### Start of BNN

    # BNN - stage - 1: filter the scaled unary with both feature scalings.
    n.out_seg1 = _bnn_permutohedral(n.scaled_unary,
                                    n.in_scaled_features1,
                                    n.out_scaled_features1)
    n.out_seg2 = _bnn_permutohedral(n.scaled_unary,
                                    n.in_scaled_features2,
                                    n.out_scaled_features2)

    n.concat_out_seg_1 = L.Concat(n.out_seg1,
                                  n.out_seg2,
                                  concat_param=dict(axis=1))
    n.concat_out_relu_1 = L.ReLU(n.concat_out_seg_1, in_place=True)

    # BNN - stage - 2: input and output features are both the output features.
    n.out_seg3 = _bnn_permutohedral(n.concat_out_relu_1,
                                    n.out_scaled_features1,
                                    n.out_scaled_features1)
    n.out_seg4 = _bnn_permutohedral(n.concat_out_relu_1,
                                    n.out_scaled_features2,
                                    n.out_scaled_features2)
    n.concat_out_seg_2 = L.Concat(n.out_seg3,
                                  n.out_seg4,
                                  concat_param=dict(axis=1))
    n.concat_out_relu_2 = L.ReLU(n.concat_out_seg_2, in_place=True)

    # BNN - combination (skip connection of both stages, then 1x1 conv)
    n.connection_out = L.Concat(n.concat_out_relu_1, n.concat_out_relu_2)
    n.spixel_out_seg = _bnn_conv(n.connection_out, num_output=2, kernel_size=1)
    n.spixel_out_seg_relu = L.ReLU(n.spixel_out_seg, in_place=True)

    # Going from superpixels to pixels
    n.out_seg_bilateral = L.Smear(n.spixel_out_seg_relu, n.spixel_indices)

    ### BNN - DeepLab Combination
    n.deeplab_seg_presoftmax = deeplab(n.padimg, n.img, n.spixel_indices)
    n.deeplab_seg = L.Softmax(n.deeplab_seg_presoftmax)
    n.bnn_deeplab_connection = L.Concat(n.out_seg_bilateral, n.deeplab_seg)
    n.bnn_deeplab_seg = _bnn_conv(n.bnn_deeplab_connection,
                                  num_output=2, kernel_size=1)
    n.bnn_deeplab_seg_relu = L.ReLU(n.bnn_deeplab_seg, in_place=True)

    ### Start of CNN (3x3 convs, padding 1 keeps the spatial size)

    # CNN - Stage 1
    n.out_seg_spatial1 = _bnn_conv(n.bnn_deeplab_seg_relu,
                                   num_output=32, kernel_size=3, pad=1)
    n.out_seg_spatial_relu1 = L.ReLU(n.out_seg_spatial1, in_place=True)

    # CNN - Stage 2
    n.out_seg_spatial2 = _bnn_conv(n.out_seg_spatial_relu1,
                                   num_output=32, kernel_size=3, pad=1)
    n.out_seg_spatial_relu2 = L.ReLU(n.out_seg_spatial2, in_place=True)

    # CNN - Stage 3 (bias initialised to 0.5)
    n.out_seg_spatial = _bnn_conv(n.out_seg_spatial_relu2,
                                  num_output=2, kernel_size=3, pad=1,
                                  bias_value=0.5)

    # Normalization
    n.out_seg = normalize(n.out_seg_spatial, 2)

    if phase in ('TRAIN', 'TEST'):
        n.loss = L.LossWithoutSoftmax(n.out_seg,
                                      n.label,
                                      loss_param=dict(ignore_label=1000),
                                      loss_weight=1)
        n.accuracy = L.Accuracy(n.out_seg,
                                n.label,
                                accuracy_param=dict(ignore_label=1000))
        # Auxiliary loss/accuracy directly on the DeepLab branch.
        n.loss2 = L.SoftmaxWithLoss(n.deeplab_seg_presoftmax,
                                    n.label,
                                    loss_param=dict(ignore_label=1000),
                                    loss_weight=1)
        n.accuracy2 = L.Accuracy(n.deeplab_seg_presoftmax,
                                 n.label,
                                 accuracy_param=dict(ignore_label=1000))
    else:
        # NOTE(review): max_spixels is hard-coded to 12000 here, while the
        # Input shapes above use the module-level max_spixels. Confirm these
        # are meant to agree.
        n.spixel_out_seg_2 = L.SpixelFeature(n.out_seg,
                                             n.spixel_indices,
                                             spixel_feature_param=dict(
                                                 type=P.SpixelFeature.AVGRGB,
                                                 max_spixels=12000,
                                                 rgb_scale=1.0))
        n.spixel_out_seg_final = normalize(n.spixel_out_seg_2, 2)

    return n.to_proto()
Exemplo n.º 6
0
def create_ssn_net(img_height, img_width,
                   num_spixels, pos_scale, color_scale,
                   num_spixels_h, num_spixels_w, num_steps,
                   phase = None):
    """Build the SSN net with superpixel pooling / EGB segmentation heads.

    Relies on helpers and settings defined elsewhere in the module:
    ``trans_dim``, ``cnn_module`` (here returning two tops), ``exec_iter``,
    ``compute_assignments``, ``compute_final_spixel_labels``,
    ``position_color_loss``, ``decode_features`` and ``normalize``.

    Args:
        img_height, img_width: spatial size of the deploy-mode Input layers.
        num_spixels: superpixel count. It is appended to the data-layer
            ``param_str`` and used as ``max_spixels`` for the initial pooling.
        pos_scale, color_scale: scales passed to the PixelFeature layer.
        num_spixels_h, num_spixels_w: superpixel grid dimensions.
        num_steps: total number of iterations (a positive integer). The first
            ``num_steps - 1`` call ``exec_iter`` and register tops
            ``spixel_feat1 .. spixel_feat{num_steps-1}``. The last calls
            ``compute_assignments`` and yields ``final_pixel_assoc``.
        phase: ``'TRAIN'`` or ``'TEST'`` adds the Python data layer and the
            losses. Any other value (default ``None``) builds a deploy net
            ending in an EgbSegment layer.

    Returns:
        The NetParameter produced by ``NetSpec.to_proto()``.

    Raises:
        ValueError: if ``num_steps`` is not a positive integer.
    """
    steps = int(num_steps)
    if steps != num_steps or steps < 1:
        raise ValueError('num_steps must be a positive integer, got %r'
                         % (num_steps,))

    n = caffe.NetSpec()

    if phase in ('TRAIN', 'TEST'):
        # Python data layer. include.phase 0/1 are Caffe's TRAIN/TEST phases.
        split, phase_id = (('TRAIN_1000000_', 0) if phase == 'TRAIN'
                           else ('VAL_10_', 1))
        (n.img, n.spixel_init, n.feat_spixel_init, n.label, n.problabel,
         n.seg_label) = L.Python(
             python_param=dict(module="input_patch_data_layer",
                               layer="InputRead",
                               param_str=split + str(num_spixels)),
             include=dict(phase=phase_id),
             ntop=6)
    else:
        n.img = L.Input(shape=[dict(dim=[1, 3, img_height, img_width])])
        n.spixel_init = L.Input(shape=[dict(dim=[1, 1, img_height, img_width])])
        n.feat_spixel_init = L.Input(shape=[dict(dim=[1, 1, img_height, img_width])])
        # Scalar (1x1x1x1) inputs fed to the EgbSegment layer below.
        n.bound_param = L.Input(shape=[dict(dim=[1, 1, 1, 1])])
        n.minsize_param = L.Input(shape=[dict(dim=[1, 1, 1, 1])])

    # PixelFeature is a custom Caffe layer. Its implementation lives under
    # lib/video_prop_networks/lib/caffe/src/caffe/layers.
    n.pixel_features = L.PixelFeature(n.img,
                                      pixel_feature_param = dict(type = P.PixelFeature.POSITION_AND_RGB,
                                                                 pos_scale = float(pos_scale),
                                                                 color_scale = float(color_scale)))

    ### Transform pixel features (trans_dim is a module-level setting; a
    ### previous comment noted 15 -- confirm). cnn_module also returns
    ### conv_dsp, which is used by the pooling/segmentation heads below.
    n.trans_features, n.conv_dsp = cnn_module(n.pixel_features, trans_dim)

    # Initial superpixel features: average of trans_features over the
    # initial superpixels given by feat_spixel_init.
    n.init_spixel_feat = L.SpixelFeature(n.trans_features, n.feat_spixel_init,
                                         spixel_feature_param = dict(
                                             type = P.SpixelFeature.AVGRGB,
                                             rgb_scale = 1.0,
                                             ignore_idx_value = -10,
                                             ignore_feature_value = 255,
                                             max_spixels = int(num_spixels)))

    ### Iterations 1 .. steps-1: refine the superpixel features.
    spixel_feat = n.init_spixel_feat
    for step in range(1, steps):
        spixel_feat = exec_iter(spixel_feat, n.trans_features,
                                n.spixel_init, num_spixels_h,
                                num_spixels_w, num_spixels, trans_dim)
        # Registered as spixel_feat1, spixel_feat2, ... in creation order.
        setattr(n, 'spixel_feat%d' % step, spixel_feat)

    ### Final iteration: compute the soft pixel-superpixel associations.
    n.final_pixel_assoc = \
        compute_assignments(spixel_feat, n.trans_features,
                            n.spixel_init, num_spixels_h,
                            num_spixels_w, num_spixels, trans_dim)

    if phase in ('TRAIN', 'TEST'):

        # Compactness loss inputs: final superpixel features computed from
        # the pixel features and the soft associations.
        n.new_spixel_feat = L.SpixelFeature2(n.pixel_features,
                                             n.final_pixel_assoc,
                                             n.spixel_init,
                                             spixel_feature2_param = dict(
                                                 num_spixels_h = num_spixels_h,
                                                 num_spixels_w = num_spixels_w))

        # Final per-pixel superpixel labels derived from the associations
        # (compute_final_spixel_labels is defined elsewhere).
        n.new_spix_indices = compute_final_spixel_labels(n.final_pixel_assoc,
                                                         n.spixel_init,
                                                         num_spixels_h, num_spixels_w)

        # Superpixel pooling: average-pool conv_dsp and seg_label within each
        # superpixel, then apply a similarity loss (weight 0.1).
        n.superpixel_pooling_out, n.superpixel_seg_label = L.SuperpixelPooling(n.conv_dsp, n.seg_label,
                                                                               n.new_spix_indices,
                                                                               superpixel_pooling_param=dict(
                                                                                   pool_type=P.Pooling.AVE), ntop=2)

        n.loss0 = L.SimilarityLoss(n.superpixel_pooling_out, n.superpixel_seg_label, n.new_spix_indices,
                                      loss_weight=0.1, similarity_loss_param=dict(sample_points=1))

        # Smear superpixel features back to pixels (no gradient into the
        # index map). Position term weighted 1e-5; color term disabled.
        n.recon_feat2 = L.Smear(n.new_spixel_feat, n.new_spix_indices,
                                propagate_down = [True, False])
        n.loss1, n.loss2 = position_color_loss(n.recon_feat2, n.pixel_features,
                                               pos_weight = 0.00001,
                                               col_weight = 0.0)

        # Task-specific reconstruction loss: convert pixel labels to
        # superpixel labels via the soft associations (final_pixel_assoc).
        n.spixel_label = L.SpixelFeature2(n.problabel,
                                          n.final_pixel_assoc,
                                          n.spixel_init,
                                          spixel_feature2_param = dict(
                                              num_spixels_h = num_spixels_h,
                                              num_spixels_w = num_spixels_w))
        # Convert superpixel labels back to pixel labels (50 channels).
        n.recon_label = decode_features(n.final_pixel_assoc, n.spixel_label, n.spixel_init,
                                        num_spixels_h, num_spixels_w, num_spixels, num_channels = 50)

        n.recon_label = L.ReLU(n.recon_label, in_place = True)
        # Add 1e-10 so values are strictly positive before normalisation.
        n.recon_label2 = L.Power(n.recon_label, power_param = dict(shift = 1e-10))
        n.recon_label3 = normalize(n.recon_label2, 50)
        n.loss3 = L.LossWithoutSoftmax(n.recon_label3, n.label,
                                       loss_param = dict(ignore_label = 255),
                                       loss_weight = 10)

    else:
        n.new_spix_indices = compute_final_spixel_labels(n.final_pixel_assoc,
                                                         n.spixel_init,
                                                         num_spixels_h, num_spixels_w)
        # Custom EgbSegment layer over conv_dsp and the superpixel labels
        # (presumably efficient graph-based segmentation -- TODO confirm how
        # the bound/min_size params interact with the bound_param and
        # minsize_param inputs).
        n.segmentation = L.EgbSegment(n.conv_dsp, n.new_spix_indices, n.bound_param, n.minsize_param,
                                      egb_segment_param=dict(bound=3, min_size=10))

    # NetSpec collects the named tops assigned as attributes. to_proto() emits
    # every layer reachable from them, using the assigned names.
    return n.to_proto()