# Create CRF model and add potentials
# =============================================================================
# zero_unsure: whether zero is a class; if it's False, it means zero can be any of the other classes
# =============================================================================
# crf = crf_model.DenseCRF(
#     num_classes = 3,
#     zero_unsure = True,              # The number of output classes
#     unary_potential=unary,
#     pairwise_potentials=[bilateral_pairwise, gaussian_pairwise],
#     use_2d = 'rgb-2d'                #'rgb-1d' or 'rgb-2d' or 'non-rgb'
# )
# =============================================================================
# Assemble the dense CRF from the module-level `unary` potential and the two
# pairwise potentials (`bilateral_pairwise` and `gaussian_pairwise`).
crf = crf_model.DenseCRF(
    unary_potential=unary,
    pairwise_potentials=[bilateral_pairwise, gaussian_pairwise],
    num_classes=5,  # The number of output classes
    zero_unsure=False,
    use_2d='rgb-1d',  # one of 'rgb-1d', 'rgb-2d' or 'non-rgb'
)


def crfing(rootpath, image, probabilities, count):
    # =============================================================================
    # Set the CRF model
    # =============================================================================
    #label_source: whether label is from softmax, or other type of label.
    crf.set_image(
        image=image,
        probabilities=probabilities,
        colour_axis=
        -1,  # The axis corresponding to colour in the image numpy shape
# Bilateral pairwise potential used by the CRF model built below.
# NOTE(review): `sdims` / `schan` are presumably the spatial and channel
# scaling factors (as in pydensecrf) - confirm against
# potentials.BilateralPotential.
bilateral_pairwise = potentials.BilateralPotential(
    kernel=dense_crf.DIAG_KERNEL,
    normalization=dense_crf.NORMALIZE_SYMMETRIC,
    compatibility=10,
    sdims=10,
    schan=0.01,
)

# =============================================================================
# Create CRF model and add potentials
# =============================================================================
# zero_unsure: whether zero is a class; if it's False, it means zero can be any of the other classes
# Assemble the two-class dense CRF from the module-level `unary` potential
# and the bilateral pairwise potential.
# NOTE(review): `pairwise_potentials` receives a single potential object here
# rather than a list - confirm that DenseCRF accepts both forms.
crf = crf_model.DenseCRF(
    unary_potential=unary,
    pairwise_potentials=bilateral_pairwise,
    num_classes=2,  # The number of output classes
    zero_unsure=False,
    use_2d='non-rgb',  # one of 'rgb-1d', 'rgb-2d' or 'non-rgb'
)

# =============================================================================
# Load image and probabilities, using the original project code
# =============================================================================
# Synthetic problem size: image height, image width and number of labels.
H, W, NLABELS = 400, 512, 2

# Gaussian blob centred on the image: evaluate a 2-D normal pdf (scalar
# covariance) at every (row, col) coordinate of the H x W grid.
pos = np.stack(np.mgrid[:H, :W], axis=-1)
rv = multivariate_normal(mean=[H // 2, W // 2], cov=(H // 4) * (W // 4))
probs = rv.pdf(pos)

# Min-max rescale the blob to [0, 1] ...
probs = probs - probs.min()
probs = probs / probs.max()
# ... then compress it around 0.5 so the values lie in [0.4, 0.6].
probs = 0.2 * (probs - 0.5) + 0.5