def create_unet(image,
                label,
                n_class,
                filter_size,
                num_of_feature,
                num_of_layers,
                keep_prob,
                name,
                reg_weight=0.001,
                debug=True,
                restore=False,
                weights=None,
                ClassWeights=(1, 1, 1)):
    """Build a U-Net segmentation network together with its training loss.

    Args:
        image: input image tensor, forwarded to ``unet``.
        label: label tensor; reshaped to ``[-1, n_class]`` and used as the
            ``labels`` of a softmax cross-entropy.
        n_class: number of output classes.
        filter_size, num_of_feature, num_of_layers, keep_prob, name:
            forwarded to ``unet``.
        reg_weight: coefficient of the L2 weight-decay term summed over all
            network variables; ``None`` disables weight decay.
        debug: when true, add activation/image/regularization/scalar
            summaries.
        restore, weights: forwarded to ``unet`` to rebuild the network from
            saved weights.
        ClassWeights: currently unused (a class-weighted loss variant was
            disabled); kept for interface compatibility.

    Returns:
        ``(loss, clean_y_out, variables, dw_h_convs)``. ``clean_y_out`` is
        the softmax over the class axis of the network logits, reshaped back
        to the logits' shape.
    """
    with tf.name_scope("u-net"):
        y_conv, variables, layer_id, dw_h_convs = unet(
            image, n_class, filter_size, num_of_feature, num_of_layers,
            keep_prob, name, debug, restore, weights)
        # Softmax over classes: flatten to [pixels, n_class], then restore
        # the original logits shape.
        clean_y_out = tf.reshape(
            tf.nn.softmax(tf.reshape(y_conv, [-1, n_class])), tf.shape(y_conv),
            'segmentation_map')

        # summary
        if debug:
            utils.add_activation_summary(clean_y_out)
            utils.add_to_image_summary(clean_y_out)
            for var in variables:
                utils.add_to_regularization_and_summary(var)

    with tf.name_scope("loss"):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=tf.reshape(label, [-1, n_class]),
                logits=tf.reshape(y_conv, [-1, n_class])))
        # Only add the L2 term when a coefficient is given. Previously the
        # product ``reg_weight * weight_decay`` was formed unconditionally,
        # so ``reg_weight=None`` raised a TypeError.
        if reg_weight is not None:
            weight_decay = 0
            for var in variables:
                weight_decay = weight_decay + tf.nn.l2_loss(var)
            loss = loss + reg_weight * weight_decay

        # reduce_sum of the scalar loss is a no-op that gives the op its name.
        loss = tf.reduce_sum(loss, name='loss')
        if debug:
            utils.add_scalar_summary(loss)

    return loss, clean_y_out, variables, dw_h_convs
def NTN(image,
        n_class,
        filter_size,
        num_of_feature,
        num_of_layers,
        keep_prob,
        name,
        debug,
        restore=False,
        weights=None,
        Unsupervised=False):
    """Attach a 1x1 class-transition layer to a U-Net segmentation output.

    The U-Net's per-pixel class softmax (``clean_y_out``) is convolved with a
    ``[1, 1, n_class, n_class]`` kernel ``w`` to produce ``noise_y_out``.

    When ``Unsupervised == False``:
      * ``w`` comes from ``utils.get_variable`` seeded with the identity.
      * ``MapTransProb`` is a ``tf.assign`` op that clips ``w`` to [0, 1]
        and rescales every row to sum to 1.

    Otherwise:
      * ``w`` comes from ``utils.weight_constant`` filled with ``1 / n_class``.
      * ``MapTransProb`` is ``None``.

    Returns:
        ``(noise_y_out, clean_y_out, MapTransProb, variables, TransProbVar,
        dw_h_convs)``, where ``TransProbVar`` is the one-element list ``[w]``.
    """
    # ``== False`` (not ``not``) keeps the original comparison semantics.
    supervised = Unsupervised == False  # noqa: E712
    kernel_shape = [1, 1, n_class, n_class]

    with tf.name_scope("u-net"):
        y_conv, variables, layer_id, dw_h_convs = unet(
            image, n_class, filter_size, num_of_feature, num_of_layers,
            keep_prob, name, debug, restore, weights)
        flat_probs = tf.nn.softmax(tf.reshape(y_conv, [-1, n_class]))
        clean_y_out = tf.reshape(flat_probs, tf.shape(y_conv),
                                 'segmentation_map')
        if debug:
            utils.add_activation_summary(clean_y_out)
            utils.add_to_image_summary(clean_y_out)

    with tf.name_scope("trans-layer"):
        if supervised:
            w = utils.get_variable(np.eye(n_class).reshape(kernel_shape),
                                   'trans-prob-weight')
        else:
            uniform = np.ones(kernel_shape) * 1. / n_class
            w = utils.weight_constant(uniform, 'trans-prob-weight')
        noise_y_out = tf.nn.conv2d(clean_y_out,
                                   w,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   name="NoisySegMap")
        TransProbVar = [w]
        if debug:
            utils.add_activation_summary(noise_y_out)
            utils.add_to_image_summary(noise_y_out)

    with tf.name_scope("MapTransProb"):
        MapTransProb = None
        if supervised:
            lower = tf.zeros([n_class, n_class], dtype=tf.float32)
            upper = tf.ones([n_class, n_class], dtype=tf.float32)
            # Clip the transition matrix into [0, 1] ...
            clipped = tf.minimum(tf.maximum(w[0, 0], lower, name="MAXIMUM"),
                                 upper,
                                 name="MINIMUM")
            # ... then divide every row by its sum.
            row_sums = tf.reshape(tf.reduce_sum(clipped, 1), [n_class, 1])
            normalised = tf.reshape(clipped / row_sums, kernel_shape)
            MapTransProb = tf.assign(w, normalised)

    return noise_y_out, clean_y_out, MapTransProb, variables, TransProbVar, dw_h_convs
def unet_upsample(image,
                  dw_h_convs,
                  variables,
                  layer_id,
                  weight_id,
                  filter_size,
                  num_of_feature,
                  num_of_layers,
                  keep_prob,
                  name,
                  debug,
                  restore=False,
                  weights=None):
    """Build the expanding (decoder) path of a U-Net.

    Starts from ``dw_h_convs[num_of_layers - 1]`` and walks the levels
    ``num_of_layers - 2 .. 0``. At each level it:
      1. applies a strided transposed conv + ReLU;
      2. crops/concatenates with the encoder output ``dw_h_convs[level]``;
      3. applies two ``conv2d_basic`` + ReLU stages.

    Args:
        image, variables: accepted for interface compatibility; not used.
        dw_h_convs: mapping from encoder level to its feature map.
        layer_id: running counter used in variable names; incremented once
            per level.
        weight_id: index of the next entry of ``weights`` to consume when
            restoring.
        filter_size: size of the square conv kernels.
        num_of_feature: feature count of the first encoder level.
        num_of_layers: number of encoder levels.
        keep_prob: forwarded to the transposed and regular convs.
        name: prefix for variable/op names.
        debug: when true, add activation/image summaries.
        restore: when ``== True``, variables come from ``weights`` in the
            order wd, bd, w1, w2, b1, b2 per level.
        weights: saved values, indexed from ``weight_id``.

    Returns:
        ``(in_node, new_variables, layer_id, weight_id)``: the final decoder
        feature map, the variables created here, and the updated counters.
    """
    stddev = 0.02
    created = []
    node = dw_h_convs[num_of_layers - 1]

    for level in reversed(range(num_of_layers - 1)):
        features = 2**(level + 1) * num_of_feature
        half = features // 2

        up_prefix = name + '_layer_up' + str(layer_id)
        conv_prefix = name + '_layer_up_conv' + str(layer_id)
        var_names = (up_prefix + '_w', up_prefix + '_b',
                     conv_prefix + '_w0', conv_prefix + '_w1',
                     conv_prefix + '_b0', conv_prefix + '_b1')
        feat_name = conv_prefix + '_feat'

        if restore == True:
            # Consume six consecutive saved tensors in creation order.
            wd, bd, w1, w2, b1, b2 = [
                utils.get_variable(weights[weight_id + offset], var_name)
                for offset, var_name in enumerate(var_names)
            ]
            weight_id += len(var_names)
        else:
            wd = utils.weight_variable([2, 2, half, features], stddev,
                                       var_names[0])
            bd = utils.bias_variable([half], var_names[1])
            w1 = utils.weight_variable(
                [filter_size, filter_size, features, half], stddev,
                var_names[2])
            w2 = utils.weight_variable(
                [filter_size, filter_size, half, half], stddev, var_names[3])
            b1 = utils.bias_variable([half], var_names[4])
            b2 = utils.bias_variable([half], var_names[5])

        upsampled = tf.nn.relu(
            utils.conv2d_transpose_strided(node, wd, bd, keep_prob=keep_prob))
        merged = utils.crop_and_concat(dw_h_convs[level], upsampled)
        hidden = tf.nn.relu(utils.conv2d_basic(merged, w1, b1, keep_prob))
        node = tf.nn.relu(utils.conv2d_basic(hidden, w2, b2, keep_prob),
                          feat_name)

        if debug:
            utils.add_activation_summary(node)
            utils.add_to_image_summary(
                utils.get_image_summary(node, feat_name + '_image'))

        created.extend((wd, bd, w1, w2, b1, b2))
        layer_id += 1

    return node, created, layer_id, weight_id
def unet_downsample(image,
                    filter_size,
                    num_of_feature,
                    num_of_layers,
                    keep_prob,
                    name,
                    debug,
                    restore=False,
                    weights=None):
    """Build the contracting (encoder) path of a U-Net.

    Level ``k`` (0-based) has ``2**k * num_of_feature`` feature maps and
    applies two ``conv2d_basic`` + ReLU stages. A 2x2 max-pool follows every
    level except the last.

    Args:
        image: input tensor; its static last dimension is the input channel
            count of the first level.
        filter_size: size of the square conv kernels.
        num_of_feature: feature count of the first level.
        num_of_layers: number of levels.
        keep_prob: forwarded to ``utils.conv2d_basic``.
        name: prefix for variable/op names.
        debug: when true, add activation/image summaries.
        restore: when ``== True``, variables come from ``weights`` in the
            order w1, w2, b1, b2 per level.
        weights: saved values, consumed from index 0.

    Returns:
        ``(dw_h_convs, variables, layer_id, weight_id)``:
          * ``dw_h_convs`` maps level -> post-ReLU feature map;
          * ``variables`` lists w1, w2, b1, b2 per level;
          * ``layer_id`` is the number of levels built;
          * ``weight_id`` is the number of saved weights consumed.
    """
    in_channels = image.get_shape().as_list()[-1]
    stddev = 0.02
    dw_h_convs = {}
    variables = []
    node = image
    weight_id = 0

    for level in range(0, num_of_layers):
        features = 2**level * num_of_feature
        prefix = name + '_layer_' + str(level)
        var_names = (prefix + '_w_0', prefix + '_w_1',
                     prefix + '_b_0', prefix + '_b_1')
        feat_name = prefix + '_feat'

        if restore == True:
            w1, w2, b1, b2 = [
                utils.get_variable(weights[weight_id + offset], var_name)
                for offset, var_name in enumerate(var_names)
            ]
            weight_id += len(var_names)
        else:
            # Only the first conv's input depth differs between level 0
            # (image channels) and deeper levels (previous level's width).
            fan_in = in_channels if level == 0 else features // 2
            w1 = utils.weight_variable(
                [filter_size, filter_size, fan_in, features], stddev,
                var_names[0])
            w2 = utils.weight_variable(
                [filter_size, filter_size, features, features], stddev,
                var_names[1])
            b1 = utils.bias_variable([features], var_names[2])
            b2 = utils.bias_variable([features], var_names[3])

        hidden = tf.nn.relu(utils.conv2d_basic(node, w1, b1, keep_prob))
        activated = tf.nn.relu(utils.conv2d_basic(hidden, w2, b2, keep_prob),
                               feat_name)
        dw_h_convs[level] = activated

        if level < num_of_layers - 1:
            node = utils.max_pool_2x2(activated)

        if debug:
            utils.add_activation_summary(activated)
            utils.add_to_image_summary(
                utils.get_image_summary(activated, feat_name + '_image'))

        variables.extend((w1, w2, b1, b2))

    return dw_h_convs, variables, len(dw_h_convs), weight_id
def autoencorder_antn(image,
                      n_class,
                      filter_size,
                      num_of_feature,
                      num_of_layers,
                      keep_prob,
                      debug,
                      restore=False,
                      shared_weights=None,
                      M_weights=None,
                      AE_weights=None):
    """Build a shared-encoder model with a segmentation decoder, an
    auto-encoder head and a per-pixel transition layer.

    Structure:
      * shared encoder: ``unet_downsample``;
      * main decoder: ``unet_upsample`` followed by a 1x1 conv to
        ``n_class`` logits;
      * auto-encoder head: a 1x1 conv + ReLU from the main decoder's output
        back to the image channel count;
      * transition layer: ``_trans_layer`` on the first encoder level,
        producing ``n_class * n_class`` maps, softmaxed in groups of
        ``n_class``;
      * noisy output: per pixel, the clean class probabilities (as a row
        vector) times the ``n_class x n_class`` transition matrix.

    Args:
        image: input tensor with static height, width and channel dims.
        n_class: number of classes.
        filter_size, num_of_feature, num_of_layers, keep_prob: network
            hyper-parameters forwarded to the encoder/decoder builders.
        debug: when true, add activation/image summaries.
        restore: when ``== True``, rebuild variables from the saved weights.
        shared_weights: saved encoder weights.
        M_weights: saved main-network weights. Indexing starts at the number
            of shared weights consumed by the encoder, because that counter
            is passed on to ``unet_upsample``.
        AE_weights: saved auto-encoder head weights, used only when
            ``restore == True`` and this is not ``None``.

    Returns:
        ``(noise_y_out, clean_y_out, y_conv, tran_y_feat, tran_map, AE_feat,
        shared_variables, AE_variables, M_variables, inter_feat)``.
    """
    stddev = 0.02
    channels = image.get_shape().as_list()[-1]

    with tf.name_scope("shared-network"):
        name = 'shared'
        inter_feat, shared_variables, layer_id, weight_id = unet_downsample(
            image, filter_size, num_of_feature, num_of_layers, keep_prob, name,
            debug, restore, shared_weights)

    with tf.name_scope("main-network"):
        name = 'main-network'
        # The decoder continues the encoder's layer/weight counters.
        M_feat, M_variables, M_layer_id, M_weight_id = unet_upsample(
            image, inter_feat, shared_variables, layer_id, weight_id,
            filter_size, num_of_feature, num_of_layers, keep_prob, name, debug,
            restore, M_weights)
        w_name = name + '_final_layer_' + str(M_layer_id) + '_w'
        b_name = name + '_final_layer_' + str(M_layer_id) + '_b'
        if restore == True:
            w = utils.get_variable(M_weights[M_weight_id], w_name)
            M_weight_id += 1
            b = utils.get_variable(M_weights[M_weight_id], b_name)
            M_weight_id += 1
        else:
            w = utils.weight_variable([1, 1, num_of_feature, n_class], stddev,
                                      w_name)
            b = utils.bias_variable([n_class], b_name)
        # 1x1 conv to class logits, then a softmax over the class axis.
        y_conv = utils.conv2d_basic(M_feat, w, b, keep_prob)
        clean_y_out = tf.reshape(
            tf.nn.softmax(tf.reshape(y_conv, [-1, n_class])), tf.shape(y_conv),
            'segmentation_map')
        M_variables.extend((w, b))
        if debug:
            utils.add_activation_summary(clean_y_out)
            utils.add_to_image_summary(clean_y_out)
        M_layer_id += 1

    with tf.name_scope("auto-encoder"):
        name = 'auto-encoder'
        w_name = name + '_final_layer_' + str(M_layer_id) + '_w'
        b_name = name + '_final_layer_' + str(M_layer_id) + '_b'
        relu_name = name + '_final_layer_' + str(M_layer_id) + '_feat'
        AE_variables = []
        # Previously AE_weights was ignored, so restoring always re-initialised
        # this head. NOTE(review): assumes AE_weights is laid out like the
        # returned AE_variables, i.e. [w, b] — confirm against the saver.
        if restore == True and AE_weights is not None:
            w = utils.get_variable(AE_weights[0], w_name)
            b = utils.get_variable(AE_weights[1], b_name)
        else:
            w = utils.weight_variable([1, 1, num_of_feature, channels], stddev,
                                      w_name)
            b = utils.bias_variable([channels], b_name)
        # 1x1 conv from the main decoder's last feature map back to the
        # input image's channel count.
        AE_feat = tf.nn.relu(utils.conv2d_basic(M_feat, w, b, keep_prob),
                             relu_name)
        AE_variables.extend((w, b))
        if debug:
            utils.add_activation_summary(AE_feat)
            utils.add_to_image_summary(
                utils.get_image_summary(AE_feat, relu_name + '_image'))

    with tf.name_scope("trans-layer"):
        tran_y_feat = _trans_layer(
            inter_feat[0], n_class * n_class,
            [image.get_shape().as_list()[1],
             image.get_shape().as_list()[2]])
        # Channels [i*n_class, (i+1)*n_class) are softmaxed together, so
        # row i of each per-pixel transition matrix sums to 1.
        class_tran_y_out = []
        for i in range(n_class):
            class_tran_y_out.append(
                tf.reshape(tf.nn.softmax(
                    tf.reshape(
                        tran_y_feat[:, :, :,
                                    i * n_class:(i * n_class + n_class)],
                        [-1, n_class])),
                           tf.shape(clean_y_out),
                           name='tran_map' + str(i)))

        tran_map = tf.concat(class_tran_y_out, 3, name='tran_map')
        if debug:
            for i in range(n_class):
                for j in range(n_class):
                    utils.add_activation_summary(tran_map[:, :, :,
                                                          i * n_class + j])
                    utils.add_to_image_summary(
                        utils.get_image_summary(tran_map,
                                                str(i) + '_to_' + str(j),
                                                i * n_class + j))

    with tf.name_scope("noisy-map-layer"):
        # Per pixel: [1, n_class] clean probabilities x [n_class, n_class]
        # transition matrix.
        noise_y_out = tf.reshape(tf.matmul(
            tf.reshape(clean_y_out, [-1, 1, n_class]),
            tf.reshape(tran_map, [-1, n_class, n_class])),
                                 tf.shape(clean_y_out),
                                 name='noise_output')

        # summary
        if debug:
            utils.add_activation_summary(noise_y_out)
            utils.add_to_image_summary(noise_y_out)

    return noise_y_out, clean_y_out, y_conv, tran_y_feat, tran_map, AE_feat, shared_variables, AE_variables, M_variables, inter_feat
def AutoencorderCLustering(image,
                           filter_size,
                           num_of_feature,
                           num_of_layers,
                           keep_prob,
                           name,
                           debug,
                           Class,
                           restore=False,
                           weights=None):
    """Build one shared U-Net encoder feeding ``Class`` independent decoders.

    Encoder:
      * ``num_of_layers`` levels, each with two ``conv2d_basic`` + ReLU
        stages and ``2**layer * num_of_feature`` feature maps;
      * a 2x2 max-pool after every level except the last.

    Each decoder branch ``k``:
      * starts from the deepest encoder output;
      * upsamples level by level with skip connections to the matching
        encoder outputs;
      * ends in a 1x1 conv + ReLU producing ``channels`` maps, where
        ``channels`` is the input image's static last dimension.

    Args:
        image: input tensor; its static last dimension gives ``channels``.
        filter_size: size of the square conv kernels.
        num_of_feature: feature count of the first encoder level.
        num_of_layers: number of encoder levels.
        keep_prob: forwarded to ``utils.conv2d_basic`` for the encoder and
            decoder convs. It is not passed to the transposed convs or to
            the final 1x1 conv.
        name: prefix for all variable/op names.
        debug: when true, add activation/image summaries.
        Class: number of decoder branches to build.
        restore: when ``== True``, variables are rebuilt from ``weights``
            with ``utils.get_variable`` instead of being freshly created.
        weights: indexable saved values, consumed in this order:
            1. per encoder level: w1, w2, b1, b2;
            2. per decoder branch: wd, bd, w1, w2, b1, b2 per level, then
               the final w, b.

    Returns:
        ``(Representation, variables, dw_h_convs)``:
          * ``Representation``: list of the ``Class`` decoder outputs;
          * ``variables``: every variable created, in the same order
            ``weights`` is consumed;
          * ``dw_h_convs``: dict mapping encoder level -> post-ReLU
            feature map.
    """
    channels = image.get_shape().as_list()[-1]
    dw_h_convs = {}
    variables = []
    pools = {}
    in_node = image

    # downsample layer
    layer_id = 0
    weight_id = 0
    for layer in range(0, num_of_layers):
        features = 2**layer * num_of_feature
        # Init stddev = sqrt(2 / (filter_size**2 * features)).
        stddev = np.sqrt(float(2) / (filter_size**2 * features))

        w1_name = name + '_layer_' + str(layer_id) + '_w_0'
        w2_name = name + '_layer_' + str(layer_id) + '_w_1'
        b1_name = name + '_layer_' + str(layer_id) + '_b_0'
        b2_name = name + '_layer_' + str(layer_id) + '_b_1'
        relu_name = name + '_layer_' + str(layer_id) + '_feat'
        # The first conv's input depth is the image channel count at level 0
        # and the previous level's width (features // 2) afterwards.
        if layer == 0:
            if restore == True:
                w1 = utils.get_variable(weights[weight_id], w1_name)
                weight_id += 1
            else:
                w1 = utils.weight_variable(
                    [filter_size, filter_size, channels, features], stddev,
                    w1_name)
        else:
            if restore == True:
                w1 = utils.get_variable(weights[weight_id], w1_name)
                weight_id += 1
            else:
                w1 = utils.weight_variable(
                    [filter_size, filter_size, features // 2, features],
                    stddev, w1_name)

        if restore == True:
            w2 = utils.get_variable(weights[weight_id], w2_name)
            weight_id += 1
            b1 = utils.get_variable(weights[weight_id], b1_name)
            weight_id += 1
            b2 = utils.get_variable(weights[weight_id], b2_name)
            weight_id += 1
        else:
            w2 = utils.weight_variable(
                [filter_size, filter_size, features, features], stddev,
                w2_name)
            b1 = utils.bias_variable([features], b1_name)
            b2 = utils.bias_variable([features], b2_name)

        conv1 = utils.conv2d_basic(in_node, w1, b1, keep_prob)
        tmp_h_conv = tf.nn.relu(conv1)
        conv2 = utils.conv2d_basic(tmp_h_conv, w2, b2, keep_prob)

        dw_h_convs[layer] = tf.nn.relu(conv2, relu_name)

        # No pooling after the deepest level.
        if layer < num_of_layers - 1:
            pools[layer] = utils.max_pool_2x2(dw_h_convs[layer])
            in_node = pools[layer]

        if debug:
            utils.add_activation_summary(dw_h_convs[layer])
            utils.add_to_image_summary(
                utils.get_image_summary(dw_h_convs[layer],
                                        relu_name + '_image'))

        variables.extend((w1, w2, b1, b2))

        layer_id += 1
    EncodedNode = dw_h_convs[num_of_layers - 1]

    # upsample layer
    # One independent decoder per branch k. layer_id keeps counting across
    # branches, and names also carry 'Class' + k, so every name is unique.
    Representation = []
    for k in range(Class):
        in_node = EncodedNode
        for layer in range(num_of_layers - 2, -1, -1):
            features = 2**(layer + 1) * num_of_feature
            stddev = np.sqrt(float(2) / (filter_size**2 * features))

            wd_name = name + '_layer_up' + str(
                layer_id) + '_w' + 'Class' + str(k)
            bd_name = name + '_layer_up' + str(
                layer_id) + '_b' + 'Class' + str(k)
            w1_name = name + '_layer_up_conv' + str(
                layer_id) + '_w0' + 'Class' + str(k)
            w2_name = name + '_layer_up_conv' + str(
                layer_id) + '_w1' + 'Class' + str(k)
            b1_name = name + '_layer_up_conv' + str(
                layer_id) + '_b0' + 'Class' + str(k)
            b2_name = name + '_layer_up_conv' + str(
                layer_id) + '_b1' + 'Class' + str(k)
            relu_name = name + '_layer_up_conv' + str(
                layer_id) + '_feat' + 'Class' + str(k)

            # pooling size is 2
            # (the transposed-conv kernel below is 2x2)
            if restore == True:
                wd = utils.get_variable(weights[weight_id], wd_name)
                weight_id += 1
                bd = utils.get_variable(weights[weight_id], bd_name)
                weight_id += 1
                w1 = utils.get_variable(weights[weight_id], w1_name)
                weight_id += 1
                w2 = utils.get_variable(weights[weight_id], w2_name)
                weight_id += 1
                b1 = utils.get_variable(weights[weight_id], b1_name)
                weight_id += 1
                b2 = utils.get_variable(weights[weight_id], b2_name)
                weight_id += 1
            else:
                wd = utils.weight_variable([2, 2, features // 2, features],
                                           stddev, wd_name)
                bd = utils.bias_variable([features // 2], bd_name)
                w1 = utils.weight_variable(
                    [filter_size, filter_size, features, features // 2],
                    stddev, w1_name)
                w2 = utils.weight_variable(
                    [filter_size, filter_size, features // 2, features // 2],
                    stddev, w2_name)
                b1 = utils.bias_variable([features // 2], b1_name)
                b2 = utils.bias_variable([features // 2], b2_name)

            # NOTE(review): unlike unet_upsample, keep_prob is not passed to
            # conv2d_transpose_strided here — confirm this is intended.
            h_deconv = tf.nn.relu(
                utils.conv2d_transpose_strided(in_node, wd, bd))

            # h_deconv_concat = utils.crop_and_concat(dw_h_convs[layer], h_deconv, tf.shape(image)[0])
            # Skip connection to the encoder output of the same level.
            h_deconv_concat = utils.crop_and_concat(dw_h_convs[layer],
                                                    h_deconv)

            conv1 = utils.conv2d_basic(h_deconv_concat, w1, b1, keep_prob)
            h_conv = tf.nn.relu(conv1)
            conv2 = utils.conv2d_basic(h_conv, w2, b2, keep_prob)

            in_node = tf.nn.relu(conv2, relu_name)
            if debug:

                # NOTE(review): both image summaries below use the same
                # relu_name + '_image' name (post-ReLU and pre-ReLU maps).
                utils.add_to_image_summary(
                    utils.get_image_summary(in_node, relu_name + '_image'))
                utils.add_to_image_summary(
                    utils.get_image_summary(conv2, relu_name + '_image'))

            variables.extend((wd, bd, w1, w2, b1, b2))
            layer_id += 1

        w_name = name + '_final_layer_' + str(layer_id) + '_w' + str(k)
        b_name = name + '_final_layer_' + str(layer_id) + '_b' + str(k)
        relu_name = name + '_final_layer_' + str(layer_id) + '_feat' + str(k)

        # NOTE(review): `stddev` here is whatever the last loop iteration
        # left behind. That is the decoder's level-0 value, or the deepest
        # encoder level's value when num_of_layers == 1.
        if restore == True:
            w = utils.get_variable(weights[weight_id], w_name)
            weight_id += 1
            b = utils.get_variable(weights[weight_id], b_name)
            weight_id += 1
        else:
            w = utils.weight_variable([1, 1, num_of_feature, channels], stddev,
                                      w_name)
            b = utils.bias_variable([channels], b_name)

        # 1x1 conv back to the input channel count, followed by ReLU.
        y_conv = tf.nn.relu(utils.conv2d_basic(in_node, w, b), relu_name)

        variables.extend((w, b))
        if debug:
            utils.add_activation_summary(y_conv)
            utils.add_to_image_summary(
                utils.get_image_summary(y_conv, relu_name))

        Representation.append(y_conv)

    return Representation, variables, dw_h_convs
def ANTN(image,
         label,
         n_class,
         filter_size,
         num_of_branch,
         num_of_feature,
         num_of_layers,
         clean_network_hidden,
         trans_network_hidden,
         debug,
         keep_prob=1.0,
         restore_clean=False,
         restore_tran=False,
         weights=None,
         all_tran_var=None):
    """Build the ANTN model (clean network + per-branch transition networks)
    and its losses.

    A "clean" U-Net predicts per-pixel class probabilities. For each of
    ``num_of_branch`` branches, a separate U-Net predicts ``n_class *
    n_class`` maps, softmaxed in groups of ``n_class`` to form a per-pixel
    transition matrix. The branch's noisy prediction is the clean
    probability row vector times that matrix.

    Args:
        image: input image tensor.
        label: per-branch targets; ``label[branch]`` is cast to float32,
            reshaped to ``[-1, n_class]`` and compared with that branch's
            noisy output via ``utils.cross_entropy``.
        n_class: number of classes.
        filter_size, num_of_feature, num_of_layers, keep_prob: forwarded to
            ``unet``.
        num_of_branch: number of transition branches.
        clean_network_hidden: integer class targets for the clean network
            (used with sparse softmax cross-entropy).
        trans_network_hidden: per-branch targets for the transition maps,
            reshaped to ``[-1, n_class * n_class]``.
        debug: when true, add summaries.
        restore_clean: forwarded to ``unet`` for the clean network.
        restore_tran: when true, each branch is rebuilt from
            ``all_tran_var[branch]``.
        weights: saved weights for the clean network.
        all_tran_var: per-branch saved weights; read only when
            ``restore_tran`` is true.

    Returns:
        ``(all_noise_loss, clean_net_loss, all_trans_net_loss, clean_y_out,
        all_tran_y_out, all_noise_y_out, clean_var, all_tran_var)``. The
        returned ``all_tran_var`` is the list of newly built per-branch
        variable lists.
    """
    # Keep the caller's saved per-branch weights: the name ``all_tran_var``
    # is reused below to collect the new variables. Previously the input was
    # overwritten before being indexed.
    saved_tran_weights = all_tran_var

    with tf.name_scope("clean-network"):
        # Other callers in this file unpack four values from ``unet``; only
        # the first three are needed here, so slice to stay compatible.
        clean_y_feat, clean_var, layer_id = unet(image, n_class, filter_size,
                                                 num_of_feature, num_of_layers,
                                                 keep_prob, 'main', debug,
                                                 restore_clean, weights)[:3]
        # Per-pixel softmax, clipped away from 0 (presumably to keep log
        # terms in the losses finite).
        clean_y_out = tf.clip_by_value(
            tf.reshape(tf.nn.softmax(tf.reshape(clean_y_feat, [-1, n_class])),
                       tf.shape(clean_y_feat), 'clean_map'), 1e-6,
            1.0)
        # summary
        if debug:
            utils.add_activation_summary(clean_y_out)
            utils.add_to_image_summary(clean_y_out)
            for var in clean_var:
                utils.add_to_regularization_and_summary(var)

    # branch process
    all_tran_y_out = []
    all_tran_var = []
    for branch in range(num_of_branch):
        with tf.name_scope("transition-network" + str(branch)):
            # Was ``restore_trans`` (undefined name -> NameError).
            if restore_tran:
                tran_y_feat, tran_var, layer_id = unet(
                    image, n_class * n_class, filter_size, num_of_feature,
                    num_of_layers, keep_prob, 'tran' + str(branch), debug,
                    restore_tran, saved_tran_weights[branch])[:3]
            else:
                tran_y_feat, tran_var, layer_id = unet(
                    image, n_class * n_class, filter_size, num_of_feature,
                    num_of_layers, keep_prob, 'tran' + str(branch), debug)[:3]
            # Channels [i*n_class, (i+1)*n_class) are softmaxed together, so
            # row i of each per-pixel transition matrix sums to 1.
            class_tran_y_out = []
            for i in range(n_class):
                class_tran_y_out.append(
                    tf.reshape(tf.nn.softmax(
                        tf.reshape(
                            tran_y_feat[:, :, :,
                                        i * n_class:(i * n_class + n_class)],
                            [-1, n_class])),
                               tf.shape(clean_y_feat),
                               name='tran_map' + str(i)))

            tran_y_out = tf.clip_by_value(
                tf.concat(class_tran_y_out, 3, name='tran_map'), 1e-6, 1.0)
            all_tran_y_out.append(tran_y_out)
            all_tran_var.append(tran_var)
            # summary
            if debug:
                for i in range(n_class):
                    for j in range(n_class):
                        # Channel i * n_class + j is entry (row i, column j)
                        # of the transition matrix. Previously channel i was
                        # summarised for every j.
                        z = tf.identity(tran_y_out[:, :, :, i * n_class + j],
                                        name=str(i) + 'to' + str(j))
                        utils.add_activation_summary(z)
                for var in tran_var:
                    utils.add_to_regularization_and_summary(var)

    # branch process
    all_noise_y_out = []
    for branch in range(num_of_branch):
        with tf.name_scope("integration" + str(branch)):
            # Per pixel: [1, n_class] clean probabilities x
            # [n_class, n_class] transition matrix.
            noise_y_out = tf.reshape(tf.matmul(
                tf.reshape(clean_y_out, [-1, 1, n_class]),
                tf.reshape(all_tran_y_out[branch], [-1, n_class, n_class])),
                                     tf.shape(clean_y_feat),
                                     name='noise_output')
            all_noise_y_out.append(noise_y_out)
            # summary
            if debug:
                utils.add_activation_summary(noise_y_out)
                utils.add_to_image_summary(noise_y_out)

    with tf.name_scope("loss"):
        clean_net_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.reshape(clean_network_hidden, [-1]),
                logits=tf.reshape(clean_y_feat, [-1, n_class])),
            name='clean_net_loss')
        # branch process
        all_noise_loss = []
        all_trans_net_loss = []
        for branch in range(num_of_branch):
            noise_loss = utils.cross_entropy(
                tf.cast(tf.reshape(label[branch], [-1, n_class]), tf.float32),
                tf.reshape(all_noise_y_out[branch], [-1, n_class]),
                'noise_loss')
            # NOTE(review): all_tran_y_out[branch] is already a clipped
            # softmax, and this applies another softmax across all
            # n_class * n_class entries. Confirm this is intended.
            trans_net_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    labels=tf.reshape(trans_network_hidden[branch],
                                      [-1, n_class * n_class]),
                    logits=tf.reshape(all_tran_y_out[branch],
                                      [-1, n_class * n_class])),
                name='trans_net_loss')
            all_noise_loss.append(noise_loss)
            all_trans_net_loss.append(trans_net_loss)

            if debug:
                utils.add_scalar_summary(noise_loss)
    return all_noise_loss, clean_net_loss, all_trans_net_loss, clean_y_out, all_tran_y_out, all_noise_y_out, clean_var, all_tran_var
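# ---------------------------------------------------------------------------
# Example 8 (separator left over from the listing this code was copied from)
# ---------------------------------------------------------------------------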
def _as_object_array(arrays):
    """Pack a sequence of arrays into a 1-D NumPy object array.

    Passing a plain list of differently-shaped arrays to ``np.save`` relies on
    NumPy implicitly building a ragged object array, which newer NumPy
    releases reject. Filling an explicit ``dtype=object`` array element by
    element keeps every entry intact regardless of its shape.
    """
    packed = np.empty(len(arrays), dtype=object)
    for i, arr in enumerate(arrays):
        packed[i] = arr
    return packed


def main(CHANNEL,
         NClass,
         FILTER_SIZE,
         NUM_OF_FEATURE,
         NUM_OF_LAYERS,
         NoisyInput,
         NoisyOutput,
         CleanInput,
         CleanOutput,
         MAX_EPOCH,
         BatchSize,
         KEEP_PROB,
         REG_WEIGHT,
         LearningRate,
         RESTORE,
         SizeLimitation,
         DirSave,
         DirLoad=None,
         Debug=True):
    """Train a U-Net on noisy labels and track its accuracy on clean labels.

    Builds a fresh TF1 graph, trains the network from ``model.create_unet`` on
    ``NoisyInput``/``NoisyOutput`` using the op returned by ``train``, and
    periodically evaluates it against ``CleanInput``/``CleanOutput``.

    Every 10 iterations a summary, a checkpoint (``DirSave + "model.ckpt"``)
    and the current variable values (``DirSave + "weights.npy"``) are written.
    Every 50 iterations the clean-label accuracy is recomputed. Monitoring and
    evaluation runs feed ``keep_prob = 1.0``; only optimisation steps use
    ``KEEP_PROB``.

    Args:
        CHANNEL: number of input image channels.
        NClass: number of classes (size of the last label axis).
        FILTER_SIZE, NUM_OF_FEATURE, NUM_OF_LAYERS: forwarded to
            ``model.create_unet``.
        NoisyInput: training images, laid out like the ``image`` placeholder
            ``[N, H, W, CHANNEL]``.
        NoisyOutput: training labels ``[N, H', W', NClass]`` (one value per
            class, judging by the placeholder shape).
        CleanInput, CleanOutput: evaluation images/labels, same layouts.
        MAX_EPOCH: number of passes over the training data.
        BatchSize: mini-batch size; a trailing partial batch is skipped.
        KEEP_PROB: value fed to the ``keep_prob`` placeholder for training
            steps (presumably a dropout keep probability; confirm in ``unet``).
        REG_WEIGHT: L2 weight-decay factor forwarded to ``create_unet``.
        LearningRate: forwarded to ``train``.
        RESTORE: if truthy, initial weights are loaded from ``DirLoad``.
        SizeLimitation: max number of samples used in monitoring/eval runs.
        DirSave: output prefix for summaries, checkpoints and weights. It is
            string-concatenated, so include a trailing path separator.
        DirLoad: ``.npy`` weights file written by a previous run (used only
            when ``RESTORE`` is truthy).
        Debug: forwarded to ``create_unet`` and ``train``.

    Returns:
        ``(clean_acc, noise_acc)`` from the most recent evaluation runs, or
        ``(None, None)`` if no training iteration was executed.
    """
    ImageSize = [NoisyInput.shape[1], NoisyInput.shape[2]]
    OutputImageSize = [NoisyOutput.shape[1], NoisyOutput.shape[2]]
    NAME = 'unet'
    # Dropout must be disabled whenever metrics/summaries are computed.
    eval_keep_prob = np.float32(1.0)
    train_keep_prob = np.float32(KEEP_PROB)

    tf.reset_default_graph()
    with tf.name_scope("Input"):
        image = tf.placeholder(
            tf.float32,
            shape=[None, ImageSize[0], ImageSize[1], CHANNEL],
            name="input_image")
        NoisyLabel = tf.placeholder(
            tf.int64,
            shape=[None, OutputImageSize[0], OutputImageSize[1], NClass],
            name="NoisyLabel")
        CleanLabel = tf.placeholder(
            tf.int64,
            shape=[None, OutputImageSize[0], OutputImageSize[1], NClass],
            name="CleanLabel")
        keep_prob = tf.placeholder(tf.float32, shape=[], name="keep_prob")
        utils.add_to_image_summary(image)

        # No internal split: all noisy data is used for training and all
        # clean data for evaluation.
        BATCHES = np.shape(NoisyInput)[0]

        tr_image_data = NoisyInput
        tr_label_data = NoisyOutput

        val_image_data = CleanInput
        val_label_data = CleanOutput

    with tf.name_scope("net"):
        weights = None
        if RESTORE:
            # SECURITY: the weights file holds an object array, so loading it
            # requires unpickling, which can execute arbitrary code. Only
            # point DirLoad at files produced by this script.
            weights = np.load(DirLoad, allow_pickle=True)

        NoisyLoss, CleanOut, variables, _ = model.create_unet(
            image,
            NoisyLabel,
            NClass,
            FILTER_SIZE,
            NUM_OF_FEATURE,
            NUM_OF_LAYERS,
            keep_prob,
            NAME,
            REG_WEIGHT,
            Debug,
            restore=bool(RESTORE),
            weights=weights)
        if RESTORE:
            print("Model restored...")

        CleanLoss = utils.cross_entropy(
            tf.cast(tf.reshape(CleanLabel, [-1, NClass]), tf.float32),
            tf.reshape(CleanOut, [-1, NClass]), 'Cleanloss')

        # Pixel accuracy of the predicted class (argmax over axis 3) against
        # the noisy and the clean labels respectively.
        NoiseAcc = tf.reduce_mean(tf.cast(
            tf.reshape(
                tf.equal(tf.argmax(NoisyLabel, 3), tf.argmax(CleanOut, 3)),
                [-1]), tf.float32),
                                  name='NoiseAcc')
        CleanAcc = tf.reduce_mean(tf.cast(
            tf.equal(tf.argmax(CleanLabel, 3), tf.argmax(CleanOut, 3)),
            tf.float32),
                                  name='CleanAcc')

        utils.add_scalar_summary(NoiseAcc)
        utils.add_scalar_summary(CleanAcc)
        utils.add_scalar_summary(CleanLoss)

    with tf.name_scope("Train"):

        trainable_var = tf.trainable_variables()
        train_op = train(NoisyLoss, trainable_var, LearningRate, Debug)
        print("Setting up summary op...")

        summary_op = tf.summary.merge_all()

        # To run on CPU only, create the session with
        # tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})).
        sess = tf.Session()

        _CleanAcc = None
        _NoiseAcc = None
        summary_writer = None
        try:
            print("Setting up Saver...")
            saver = tf.train.Saver()
            summary_writer = tf.summary.FileWriter(DirSave, sess.graph)

            sess.run(tf.global_variables_initializer())

            # Fixed subsets used for monitoring and evaluation. They are taken
            # before any shuffling, so the same samples are scored each time.
            tr_image_batch1 = tr_image_data[0:SizeLimitation]
            tr_label_batch1 = tr_label_data[0:SizeLimitation]

            val_image_batch = val_image_data[0:SizeLimitation]
            val_label_batch = val_label_data[0:SizeLimitation]

            # summary_op includes the CleanLoss/CleanAcc summaries, so
            # CleanLabel must be fed here too.
            # NOTE(review): CleanLabel receives the evaluation labels while
            # `image` holds training images. Those summaries are only
            # meaningful if both sets share the same images; confirm.
            monitor_feed_dict = {
                image: tr_image_batch1,
                NoisyLabel: tr_label_batch1,
                CleanLabel: val_label_batch,
                keep_prob: eval_keep_prob
            }
            val_feed_dict = {
                image: val_image_batch,
                CleanLabel: val_label_batch,
                keep_prob: eval_keep_prob
            }

            # Floor division: `/` yields a float on Python 3 and breaks range().
            num_batches = BATCHES // BatchSize
            total_iter = 0
            for _ in range(MAX_EPOCH):
                for batch in range(num_batches):
                    start = batch * BatchSize
                    stop = start + BatchSize
                    tr_feed_dict = {
                        image: tr_image_data[start:stop],
                        NoisyLabel: tr_label_data[start:stop],
                        keep_prob: train_keep_prob
                    }

                    # Monitoring: summaries, checkpoint and raw weights.
                    if total_iter % 10 == 0:
                        _NoiseAcc, tr_variables, summary_str = sess.run(
                            [NoiseAcc, variables, summary_op],
                            feed_dict=monitor_feed_dict)
                        summary_writer.add_summary(summary_str, total_iter)
                        saver.save(sess, DirSave + "model.ckpt", total_iter)
                        np.save(DirSave + "weights",
                                _as_object_array(tr_variables))

                    # Evaluation against the clean labels.
                    if total_iter % 50 == 0:
                        _CleanAcc = sess.run(CleanAcc,
                                             feed_dict=val_feed_dict)

                    sess.run(train_op, feed_dict=tr_feed_dict)
                    total_iter += 1

                # Reshuffle the training samples for the next epoch.
                new_index = random.sample(range(BATCHES), BATCHES)
                tr_image_data = tr_image_data[new_index]
                tr_label_data = tr_label_data[new_index]
        finally:
            if summary_writer is not None:
                summary_writer.close()
            sess.close()

    return _CleanAcc, _NoiseAcc