コード例 #1
0
def single_resblock(adain_use, is_training, residual_device, initializer,
                    scope, weight_decay, weight_decay_rate, x, layer, style,
                    filters, other_info):
    """Build one pre-activation residual block: (norm -> ReLU -> 3x3 conv) x 2.

    Normalisation is ``batch_norm`` when ``adain_use`` is falsy, otherwise
    ``adaptive_instance_norm`` driven by ``style``.

    Args:
        adain_use: If truthy, normalise with ``adaptive_instance_norm`` using
            ``style``; otherwise use ``batch_norm``.
        is_training: Forwarded to ``batch_norm``.
        residual_device: Passed as ``parameter_update_device`` to every layer.
        initializer: Weight initializer for both convolutions.
        scope: Passed as ``name_prefix`` to both convolutions.
        weight_decay: Forwarded to ``conv2d``.
        weight_decay_rate: Forwarded to ``conv2d``.
        x: Input feature map. Its channel count is read from ``x.shape[3]``
            (presumably NHWC -- confirm against callers).
        layer: Integer used to build the per-layer scope names
            ("layer%d_bn1", "layer%d_conv1", ...).
        style: Style tensor for AdaIN. On the 'DenseMixer' path its channel
            count is read from ``style.shape[4]``, so it is a 5-D tensor there
            (axis layout not visible here -- TODO confirm).
        filters: Number of output filters of both convolutions.
        other_info: 'ResidualMixer' returns ``x + conv2`` (identity shortcut).
            'DenseMixer' returns ``conv2`` only, and the first AdaIN tiles
            ``style`` along its last axis to match the channels of ``x``.

    Returns:
        The block output tensor.

    Raises:
        ValueError: If ``other_info`` is neither 'ResidualMixer' nor
            'DenseMixer'.
    """
    if other_info not in ('ResidualMixer', 'DenseMixer'):
        # Fail fast, before any variables or ops are created. Previously an
        # unknown mode surfaced later as an UnboundLocalError on norm1/output.
        raise ValueError("other_info must be 'ResidualMixer' or "
                         "'DenseMixer', got %r" % (other_info,))

    if not adain_use:
        norm1 = batch_norm(x=x,
                           is_training=is_training,
                           scope="layer%d_bn1" % layer,
                           parameter_update_device=residual_device)
    elif other_info == 'DenseMixer':
        # Repeat the style tensor along its last axis so its channel count
        # matches that of x.
        travel_times = int(x.shape[3]) // int(style.shape[4])
        style_tile = tf.tile(style, [1, 1, 1, 1, travel_times])
        norm1 = adaptive_instance_norm(content=x, style=style_tile)
    else:  # 'ResidualMixer'
        norm1 = adaptive_instance_norm(content=x, style=style)

    act1 = relu(norm1)
    conv1 = conv2d(x=act1,
                   output_filters=filters,
                   scope="layer%d_conv1" % layer,
                   parameter_update_device=residual_device,
                   kh=3,
                   kw=3,
                   sh=1,
                   sw=1,
                   initializer=initializer,
                   weight_decay=weight_decay,
                   name_prefix=scope,
                   weight_decay_rate=weight_decay_rate)

    if not adain_use:
        norm2 = batch_norm(x=conv1,
                           is_training=is_training,
                           scope="layer%d_bn2" % layer,
                           parameter_update_device=residual_device)
    else:
        # NOTE(review): unlike norm1, the DenseMixer path uses the untiled
        # style here; presumably conv1 (with `filters` channels) already
        # matches style's channel count -- confirm.
        norm2 = adaptive_instance_norm(content=conv1, style=style)

    act2 = relu(norm2)
    conv2 = conv2d(x=act2,
                   output_filters=filters,
                   scope="layer%d_conv2" % layer,
                   parameter_update_device=residual_device,
                   initializer=initializer,
                   weight_decay=weight_decay,
                   name_prefix=scope,
                   weight_decay_rate=weight_decay_rate,
                   kh=3,
                   kw=3,
                   sh=1,
                   sw=1)

    if other_info == 'ResidualMixer':
        # Identity shortcut; presumably requires `filters` to equal the
        # channel count of x -- confirm against callers.
        return x + conv2
    return conv2
コード例 #2
0
def encoder_adobenet_framework(images,
                               is_training,
                               encoder_device,
                               scope,initializer,
                               weight_decay,weight_decay_rate,
                               residual_at_layer=-1,
                               residual_connection_mode=None,
                               reuse=False,
                               adain_use=False):
    """Build the AdobeNet encoder: one 7x7 stride-1, 64-filter conv + ReLU.

    Args:
        images: Input image tensor fed to the convolution.
        is_training: Unused by this encoder.
        encoder_device: Device used for ``tf.device`` and as
            ``parameter_update_device`` of the convolution.
        scope: Variable scope name; also passed as ``name_prefix``.
        initializer: Weight initializer for the convolution.
        weight_decay: Forwarded to ``conv2d``.
        weight_decay_rate: Forwarded to ``conv2d``.
        residual_at_layer: Ignored. Accepted for signature compatibility with
            the other encoder frameworks.
        residual_connection_mode: Ignored (see ``residual_at_layer``).
        reuse: If True, reuse existing variables inside ``scope``.
        adain_use: Ignored (see ``residual_at_layer``).

    Returns:
        A 5-tuple ``(conv1, -1, -1, full_feature_list, return_str)``.
        ``full_feature_list`` holds the single encoder feature map, and
        ``return_str`` is a human-readable description. The two ``-1`` values
        are fixed placeholders; their meaning is defined by the callers.
    """
    # The residual/AdaIN options are deliberately unused here. The previous
    # version rebound them to constants that were never read.
    full_feature_list = []

    with tf.variable_scope(tf.get_variable_scope()):
        with tf.device(encoder_device):
            with tf.variable_scope(scope):
                if reuse:
                    tf.get_variable_scope().reuse_variables()
                conv1 = relu(conv2d(x=images,
                                    output_filters=64,
                                    kh=7, kw=7, sh=1, sw=1,
                                    scope="layer%d_conv" % 1,
                                    parameter_update_device=encoder_device,
                                    initializer=initializer,
                                    weight_decay=weight_decay,
                                    name_prefix=scope,
                                    weight_decay_rate=weight_decay_rate))
                full_feature_list.append(conv1)

    return_str = "AdobeNet-Encoder %d Layers" % len(full_feature_list)
    return conv1, -1, -1, full_feature_list, return_str
コード例 #3
0
ファイル: vggs.py プロジェクト: falconjhc/W-Net
# (block index, output filters, number of 3x3 conv layers) for the five
# convolutional blocks of VGG-16.
_VGG16_CONV_BLOCKS = ((1, 64, 2),
                      (2, 128, 2),
                      (3, 256, 3),
                      (4, 512, 3),
                      (5, 512, 3))


def _vgg_conv(x, filters, scope, device, initializer, weight_decay,
              weight_decay_rate):
    """Apply the 3x3, stride-1, SAME-padded convolution used by every VGG layer."""
    return conv2d(x=x,
                  output_filters=filters,
                  kh=3,
                  kw=3,
                  sh=1,
                  sw=1,
                  padding='SAME',
                  parameter_update_device=device,
                  weight_decay=weight_decay,
                  initializer=initializer,
                  scope=scope,
                  weight_decay_rate=weight_decay_rate)


def _vgg_bn_relu(x, is_training, scope, device):
    """Apply batch normalisation followed by ReLU."""
    return relu(batch_norm(x=x,
                           is_training=is_training,
                           scope=scope,
                           parameter_update_device=device))


def vgg_16_net(image,
               batch_size,
               device,
               keep_prob,
               initializer,
               reuse=False,
               network_usage='-1',
               output_high_level_features=(-1,)):
    """Build a batch-normalised VGG-16 feature extractor (inference mode).

    Each of the five conv blocks applies (3x3 conv -> BN -> ReLU) N times,
    then a 2x2/stride-2 max-pool. Two fully-connected layers (fc6, fc7) of
    4096 units follow, with ReLU + dropout between them. Batch norm runs with
    ``is_training=False`` and weight decay is disabled.

    Args:
        image: Input image tensor.
        batch_size: Leading dimension used when flattening for fc6/fc7.
        device: Passed as ``parameter_update_device`` to every layer.
        keep_prob: Dropout keep probability applied after fc6.
        initializer: Weight initializer for all conv/fc layers.
        reuse: If True, reuse existing variables in the usage scope.
        network_usage: Prefix of the variable scope
            (``network_usage + '/ext_vgg16net'``).
        output_high_level_features: Container of ints selecting which
            features to return:
            - 1..5: the output of the last conv of that block, taken before
              its BN/ReLU.
            - 6: fc6, before ReLU/dropout.
            - 7: fc7.
            The default ``(-1,)`` selects nothing. It was previously a
            mutable list default; any container supporting ``in`` works.

    Returns:
        ``(features, "Vgg16Net")``, where ``features`` lists the selected
        tensors in ascending layer order.
    """
    is_training = False
    weight_decay = False
    return_str = "Vgg16Net"
    # `eps` is a module-level constant defined outside this view.
    weight_decay_rate = eps

    usage_scope = network_usage + '/ext_vgg16net'

    with tf.variable_scope(usage_scope):
        if reuse:
            tf.get_variable_scope().reuse_variables()

        features = []
        net = image

        # Blocks 1-5. Scope names (convB_I / bnB_I / poolB) and op creation
        # order are identical to the original unrolled implementation.
        for block, filters, num_convs in _VGG16_CONV_BLOCKS:
            for conv_idx in range(1, num_convs + 1):
                net = _vgg_conv(net, filters, 'conv%d_%d' % (block, conv_idx),
                                device, initializer, weight_decay,
                                weight_decay_rate)
                if conv_idx == num_convs and block in output_high_level_features:
                    # Exported block features are the pre-BN conv output.
                    features.append(net)
                net = _vgg_bn_relu(net, is_training,
                                   'bn%d_%d' % (block, conv_idx), device)
            net = tf.nn.max_pool(value=net,
                                 ksize=[1, 2, 2, 1],
                                 strides=[1, 2, 2, 1],
                                 padding='SAME',
                                 name='pool%d' % block)

        # block 6
        fc6 = tf.reshape(net, [batch_size, -1])
        fc6 = fc(x=fc6,
                 output_size=4096,
                 scope="fc6",
                 weight_decay=weight_decay,
                 initializer=initializer,
                 parameter_update_device=device,
                 weight_decay_rate=weight_decay_rate)
        if 6 in output_high_level_features:
            features.append(fc6)
        fc6 = tf.nn.dropout(x=relu(fc6), keep_prob=keep_prob)

        # block 7
        fc7 = tf.reshape(fc6, [batch_size, -1])
        fc7 = fc(x=fc7,
                 output_size=4096,
                 scope="fc7",
                 weight_decay=weight_decay,
                 initializer=initializer,
                 parameter_update_device=device,
                 weight_decay_rate=weight_decay_rate)
        if 7 in output_high_level_features:
            features.append(fc7)

        return features, return_str
コード例 #4
0
def encoder_resmixernet_framework(images,
                                  is_training,
                                  encoder_device,
                                  scope,initializer,
                                  weight_decay,weight_decay_rate,
                                  residual_at_layer=-1,
                                  residual_connection_mode=None,
                                  reuse=False,
                                  adain_use=False):
    """Build the ResMixerNet encoder: three (conv -> instance norm -> ReLU) layers.

    Layer specs are:
    - 7x7 stride-1 conv with 64 filters;
    - 3x3 stride-2 conv with 128 filters;
    - 3x3 stride-2 conv with 256 filters.

    Args:
        images: Input image tensor.
        is_training: Unused by this encoder.
        encoder_device: Device used for ``tf.device`` and as
            ``parameter_update_device`` of every layer.
        scope: Variable scope name; also passed as ``name_prefix``.
        initializer: Weight initializer for the convolutions.
        weight_decay: Forwarded to ``conv2d``.
        weight_decay_rate: Forwarded to ``conv2d``.
        residual_at_layer: Ignored. Accepted for signature compatibility with
            the other encoder frameworks.
        residual_connection_mode: Ignored (see ``residual_at_layer``).
        reuse: If True, reuse existing variables inside ``scope``.
        adain_use: Ignored (see ``residual_at_layer``).

    Returns:
        A 5-tuple ``(last_features, -1, -1, full_feature_list, return_str)``.
        ``full_feature_list`` holds the output of every layer in order. The
        two ``-1`` values are fixed placeholders whose meaning is defined by
        the callers.
    """
    # The residual/AdaIN options are deliberately unused here. The previous
    # version rebound them to constants that were never read.

    # (layer index, output filters, kernel size, stride)
    layer_specs = ((1, 64, 7, 1),
                   (2, 128, 3, 2),
                   (3, 256, 3, 2))
    full_feature_list = []

    with tf.variable_scope(tf.get_variable_scope()):
        with tf.device(encoder_device):
            with tf.variable_scope(scope):
                if reuse:
                    tf.get_variable_scope().reuse_variables()

                features = images
                for layer, filters, ksize, stride in layer_specs:
                    conv = conv2d(x=features,
                                  output_filters=filters,
                                  kh=ksize, kw=ksize, sh=stride, sw=stride,
                                  scope="layer%d_conv" % layer,
                                  parameter_update_device=encoder_device,
                                  initializer=initializer,
                                  weight_decay=weight_decay,
                                  name_prefix=scope,
                                  weight_decay_rate=weight_decay_rate)
                    features = relu(instance_norm(
                        conv,
                        scope="layer%d_in" % layer,
                        parameter_update_device=encoder_device))
                    full_feature_list.append(features)

    # NOTE(review): the label says "DenseResNet" although this is the
    # ResMixerNet encoder. It is kept unchanged in case logs/callers match it.
    return_str = "DenseResNet-Encoder %d Layers" % len(full_feature_list)
    return features, -1, -1, full_feature_list, return_str
コード例 #5
0
ファイル: decoders.py プロジェクト: falconjhc/W-Net
def decoder_adobenet_framework(encoded_layer_list,
                               decoder_input_org,
                               is_training,
                               output_width,
                               output_filters,
                               batch_size,
                               decoder_device,
                               scope,
                               initializer,
                               weight_decay,
                               weight_decay_rate,
                               adain_use,
                               reuse=False,
                               other_info=None):
    """Build the AdobeNet decoder.

    The layers, applied in order to ``decoder_input_org``, are:
    1. a normal 3x3 conv residual block;
    2. two dilated residual blocks (dilation 2, then 4);
    3. a dilated 3x3 conv (128 filters, dilation 2) with batch norm and ReLU;
    4. a final 3x3 conv with 1 output filter and tanh.

    Args:
        encoded_layer_list: Unused by this decoder.
        decoder_input_org: Input tensor to the first residual block.
        is_training: Forwarded to the residual blocks and batch norm.
        output_width: Unused by this decoder.
        output_filters: Unused. The final conv always emits 1 filter
            (NOTE(review): confirm this hard-coded value is intended).
        batch_size: Unused by this decoder.
        decoder_device: Device used for ``tf.device`` and as the parameter
            update device of every layer.
        scope: Variable scope name; also passed as ``name_prefix``.
        initializer: Weight initializer for all layers.
        weight_decay: Forwarded to every conv layer.
        weight_decay_rate: Forwarded to every conv layer.
        adain_use: Ignored. Accepted for signature compatibility with the
            other decoder frameworks.
        reuse: If True, reuse existing variables inside ``scope``.
        other_info: Unused by this decoder.

    Returns:
        ``(generated_img, full_feature_list, return_str)``.
        ``full_feature_list`` holds every intermediate output, ending with
        ``generated_img``.
    """
    # The previous version assigned residual_connection_mode,
    # residual_at_layer and adain_use to constants that were never read;
    # those dead stores are removed.
    full_feature_list = []
    with tf.variable_scope(tf.get_variable_scope()):
        with tf.device(decoder_device):
            with tf.variable_scope(scope):
                if reuse:
                    tf.get_variable_scope().reuse_variables()
                normal_conv_resblock1 = normal_conv_resblock(
                    x=decoder_input_org,
                    initializer=initializer,
                    is_training=is_training,
                    layer=1,
                    kh=3,
                    kw=3,
                    sh=1,
                    sw=1,
                    batch_norm_used=True,
                    weight_decay=weight_decay,
                    weight_decay_rate=weight_decay_rate,
                    scope="layer%d_normal_resblock" % 1,
                    parameter_update_devices=decoder_device)
                full_feature_list.append(normal_conv_resblock1)

                dilated_conv_resblock1 = dilated_conv_resblock(
                    x=normal_conv_resblock1,
                    initializer=initializer,
                    is_training=is_training,
                    layer=2,
                    dilation=2,
                    kh=3,
                    kw=3,
                    batch_norm_used=True,
                    weight_decay=weight_decay,
                    weight_decay_rate=weight_decay_rate,
                    scope="layer%d_dilated_resblock" % 2,
                    parameter_update_devices=decoder_device)
                full_feature_list.append(dilated_conv_resblock1)

                dilated_conv_resblock2 = dilated_conv_resblock(
                    x=dilated_conv_resblock1,
                    initializer=initializer,
                    is_training=is_training,
                    layer=3,
                    dilation=4,
                    kh=3,
                    kw=3,
                    batch_norm_used=True,
                    weight_decay=weight_decay,
                    weight_decay_rate=weight_decay_rate,
                    scope="layer%d_dilated_resblock" % 3,
                    parameter_update_devices=decoder_device)
                full_feature_list.append(dilated_conv_resblock2)

                dilated_conv_1 = relu(
                    batch_norm(x=dilated_conv2d(
                        x=dilated_conv_resblock2,
                        output_filters=128,
                        weight_decay_rate=weight_decay_rate,
                        weight_decay=weight_decay,
                        kh=3,
                        kw=3,
                        dilation=2,
                        initializer=initializer,
                        scope="layer%d_dilated_conv" % 4,
                        parameter_update_device=decoder_device,
                        name_prefix=scope),
                               is_training=is_training,
                               scope="layer%d_bn" % 4,
                               parameter_update_device=decoder_device))
                full_feature_list.append(dilated_conv_1)

                generated_img = tf.nn.tanh(
                    conv2d(x=dilated_conv_1,
                           output_filters=1,
                           weight_decay_rate=weight_decay_rate,
                           weight_decay=weight_decay,
                           kh=3,
                           kw=3,
                           sw=1,
                           sh=1,
                           initializer=initializer,
                           scope="layer%d_normal_conv" % 5,
                           parameter_update_device=decoder_device,
                           name_prefix=scope))
                full_feature_list.append(generated_img)

    return_str = "AdobeNet-Decoder %d Layers" % len(full_feature_list)

    return generated_img, full_feature_list, return_str
コード例 #6
0
def residual_block_implementation(input_list,
                                  residual_num,
                                  residual_at_layer,
                                  is_training,
                                  residual_device,
                                  reuse,
                                  scope,
                                  initializer,
                                  weight_decay,
                                  weight_decay_rate,
                                  style_features,
                                  other_info,
                                  adain_use=False,
                                  adain_preparation_model=None,
                                  debug_mode=False):

    return_str = "Residual %d Blocks" % residual_num
    input_list.reverse()
    with tf.variable_scope(tf.get_variable_scope()):
        with tf.device(residual_device):
            residual_output_list = list()

            if not reuse:
                print(print_separater)
                print(
                    'Adaptive Instance Normalization for Residual Preparations: %s'
                    % adain_preparation_model)
                print(print_separater)

            for ii in range(len(input_list)):
                current_residual_num = residual_num + 2 * ii
                current_residual_input = input_list[ii]
                current_scope = scope + '_onEncDecLyr%d' % (residual_at_layer -
                                                            ii)

                if adain_use:
                    with tf.variable_scope(current_scope):
                        for jj in range(len(style_features)):
                            if int(style_features[jj].shape[2]) == int(
                                    current_residual_input.shape[1]):
                                break

                        for jj in range(int(style_features[ii].shape[0])):
                            if reuse or jj > 0:
                                tf.get_variable_scope().reuse_variables()

                            batch_size = int(
                                style_features[ii][jj, :, :, :, :].shape[0])
                            if batch_size == 1:
                                current_init_residual_input = style_features[
                                    ii][jj, :, :, :, :]
                            else:
                                current_init_residual_input = tf.squeeze(
                                    style_features[ii][jj, :, :, :, :])

                            if adain_preparation_model == 'Multi':
                                # multiple cnn layer built to make the style_conv be incorporated with the dimension of the residual blocks
                                log_input = math.log(
                                    int(current_init_residual_input.shape[3])
                                ) / math.log(2)
                                if math.log(
                                        int(current_init_residual_input.
                                            shape[3])) < math.log(
                                                int(current_residual_input.
                                                    shape[3])):
                                    if np.floor(log_input) < math.log(
                                            int(current_residual_input.shape[3]
                                                )) / math.log(2):
                                        filter_num_start = int(
                                            np.floor(log_input)) + 1
                                    else:
                                        filter_num_start = int(
                                            np.floor(log_input))
                                    filter_num_start = int(
                                        math.pow(2, filter_num_start))
                                elif math.log(
                                        int(current_init_residual_input.
                                            shape[3])) > math.log(
                                                int(current_residual_input.
                                                    shape[3])):
                                    if np.ceil(log_input) > math.log(
                                            int(current_residual_input.shape[3]
                                                )) / math.log(2):
                                        filter_num_start = int(
                                            np.ceil(log_input)) - 1
                                    else:
                                        filter_num_start = int(
                                            np.ceil(log_input))
                                    filter_num_start = int(
                                        math.pow(2, filter_num_start))
                                else:
                                    filter_num_start = int(
                                        current_residual_input.shape[3])
                                filter_num_end = int(
                                    current_residual_input.shape[3])

                                if int(current_init_residual_input.shape[3]
                                       ) == filter_num_end:
                                    continue_build = False
                                    style_conv = current_init_residual_input
                                else:
                                    continue_build = True

                                current_style_conv_input = current_init_residual_input
                                current_output_filter_num = filter_num_start
                                style_cnn_layer_num = 0
                                while continue_build:
                                    style_conv = conv2d(
                                        x=current_style_conv_input,
                                        output_filters=
                                        current_output_filter_num,
                                        scope="conv0_style_layer%d" %
                                        (style_cnn_layer_num + 1),
                                        parameter_update_device=residual_device,
                                        kh=3,
                                        kw=3,
                                        sh=1,
                                        sw=1,
                                        initializer=initializer,
                                        weight_decay=weight_decay,
                                        name_prefix=scope,
                                        weight_decay_rate=weight_decay_rate)
                                    if not (reuse or jj > 0):
                                        print(style_conv)
                                    style_conv = relu(style_conv)

                                    current_style_conv_input = style_conv

                                    if filter_num_start < filter_num_end:
                                        current_output_filter_num = current_output_filter_num * 2
                                    else:
                                        current_output_filter_num = current_output_filter_num / 2
                                    style_cnn_layer_num += 1

                                    if current_output_filter_num > filter_num_end and \
                                            math.log(int(current_init_residual_input.shape[3])) \
                                            < math.log(int(current_residual_input.shape[3])):
                                        current_output_filter_num = filter_num_end
                                    if current_output_filter_num < filter_num_end and \
                                            math.log(int(current_init_residual_input.shape[3])) \
                                            > math.log(int(current_residual_input.shape[3])):
                                        current_output_filter_num = filter_num_end

                                    if int(style_conv.shape[3]
                                           ) == filter_num_end:
                                        continue_build = False

                            elif adain_preparation_model == 'Single':
                                if int(current_init_residual_input.shape[3]
                                       ) == int(
                                           current_residual_input.shape[3]):
                                    style_conv = current_init_residual_input
                                else:
                                    style_conv = conv2d(
                                        x=current_init_residual_input,
                                        output_filters=int(
                                            current_residual_input.shape[3]),
                                        scope="conv0_style_layer0",
                                        parameter_update_device=residual_device,
                                        kh=3,
                                        kw=3,
                                        sh=1,
                                        sw=1,
                                        initializer=initializer,
                                        weight_decay=weight_decay,
                                        name_prefix=scope,
                                        weight_decay_rate=weight_decay_rate)
                                    if not (reuse or jj > 0):
                                        print(style_conv)
                                    style_conv = relu(style_conv)

                            if jj == 0:
                                style_features_new = tf.expand_dims(style_conv,
                                                                    axis=0)
                            else:
                                style_features_new = tf.concat([
                                    style_features_new,
                                    tf.expand_dims(style_conv, axis=0)
                                ],
                                                               axis=0)

                    if (not reuse) and (
                            not math.log(
                                int(current_init_residual_input.shape[3]))
                            == math.log(int(current_residual_input.shape[3]))):
                        print(print_separater)
                else:
                    style_features_new = None

                with tf.variable_scope(current_scope):
                    if reuse:
                        tf.get_variable_scope().reuse_variables()

                    tmp_residual_output_list_on_current_place = list()
                    filter_num = int(current_residual_input.shape[3])
                    for jj in range(current_residual_num):
                        if jj == 0:
                            residual_input = current_residual_input
                        else:
                            if other_info == 'DenseMixer':
                                for kk in range(
                                        len(tmp_residual_output_list_on_current_place
                                            )):
                                    if kk == 0:
                                        residual_input = tmp_residual_output_list_on_current_place[
                                            kk]
                                    else:
                                        residual_input = tf.concat([
                                            residual_input,
                                            tmp_residual_output_list_on_current_place[
                                                kk]
                                        ],
                                                                   axis=3)
                            elif other_info == 'ResidualMixer':
                                residual_input = residual_block_output
                        residual_block_output = \
                            single_resblock(adain_use=adain_use,
                                            is_training=is_training,
                                            residual_device=residual_device,
                                            initializer=initializer,
                                            scope=scope,
                                            weight_decay=weight_decay,
                                            weight_decay_rate=weight_decay_rate,
                                            x=residual_input,
                                            layer=jj+1,
                                            style=style_features_new,
                                            filters=filter_num,
                                            other_info=other_info)
                        tmp_residual_output_list_on_current_place.append(
                            residual_block_output)
                        if jj == current_residual_num - 1:
                            residual_output = residual_block_output

                    residual_output_list.append(residual_output)

    # if (not reuse) and adain_use and (not debug_mode):
    #     print(print_separater)
    #     raw_input("Press enter to continue")
    # print(print_separater)

    return residual_output_list, return_str
コード例 #7
0
def emdnet_mixer_non_adain(generator_device, reuse, scope, initializer,
                           weight_decay, weight_decay_rate,
                           encoded_content_final, content_shortcut_interface,
                           encoded_style_final):
    """Build the non-AdaIN EMD mixer that fuses content and style encodings.

    The final content and style encodings are each squeezed down to rank 2.
    Each one is then projected through a leaky-ReLU fully connected layer of
    width ``generator_dim``, a name resolved outside this function. The two
    projections are combined by ``emd_mixer``. The mix is projected back
    through a ReLU fully connected layer whose width is
    ``encoded_content_final.shape[3]``. Finally, two unit axes are inserted at
    axis 1 (presumably yielding ``[batch, 1, 1, channels]`` if ``fc`` returns
    rank 2 -- TODO confirm).

    Only the first and last entries of ``content_shortcut_interface`` are kept
    as shortcuts. Every other slot becomes ``None``, and the resulting list is
    returned in reverse order.

    Args:
        generator_device: device passed to ``tf.device``, to every fully
            connected layer and to ``emd_mixer``.
        reuse: when truthy, variables inside ``scope`` are reused.
        scope: variable scope name; also forwarded to ``fc`` as
            ``name_prefix``.
        initializer: forwarded to ``fc`` and ``emd_mixer``.
        weight_decay, weight_decay_rate: forwarded to ``fc``.
        encoded_content_final: content encoding. It must have at least 4
            axes, because axis 3 is read as the output width.
        content_shortcut_interface: indexable sequence of per-layer content
            features.
        encoded_style_final: style encoding.

    Returns:
        tuple: ``(shortcut_list, mixed_fc, batch_diff)``. Here ``batch_diff``
        is the mean of ``_calculate_batch_diff`` over the kept shortcuts.

    Raises:
        ZeroDivisionError: if ``content_shortcut_interface`` is empty.
    """

    def _as_matrix(tensor):
        # Drop every size-1 axis. A batch of one collapses to a 1-D vector,
        # so restore the leading axis in that case.
        flat = tf.squeeze(tensor)
        if len(flat.shape.as_list()) == 1:
            flat = tf.expand_dims(flat, axis=0)
        return flat

    def _dense(inputs, width, layer_scope):
        # Fully connected layer sharing this mixer's device and regularization.
        return fc(x=inputs,
                  output_size=width,
                  scope=layer_scope,
                  parameter_update_device=generator_device,
                  initializer=initializer,
                  weight_decay=weight_decay,
                  name_prefix=scope,
                  weight_decay_rate=weight_decay_rate)

    with tf.variable_scope(tf.get_variable_scope()), \
            tf.device(generator_device), \
            tf.variable_scope(scope):
        if reuse:
            tf.get_variable_scope().reuse_variables()

        content_matrix = _as_matrix(encoded_content_final)
        style_matrix = _as_matrix(encoded_style_final)

        content_fc = lrelu(_dense(content_matrix, generator_dim,
                                  "emd_mixer/content_fc"))
        style_fc = lrelu(_dense(style_matrix, generator_dim,
                                "emd_mixer/style_fc"))
        mixed = emd_mixer(content=content_fc,
                          style=style_fc,
                          initializer=initializer,
                          device=generator_device)
        mixed_fc = relu(_dense(mixed, int(encoded_content_final.shape[3]),
                               "emd_mixer/mixed_fc"))

        # Insert two unit axes right after the leading (batch) axis.
        for _ in range(2):
            mixed_fc = tf.expand_dims(mixed_fc, axis=1)

        last_index = len(content_shortcut_interface) - 1
        shortcut_list = []
        batch_diff = 0
        kept = 0
        for idx, feature in enumerate(content_shortcut_interface):
            if idx not in (0, last_index):
                # Intermediate shortcuts are replaced by None placeholders.
                shortcut_list.append(None)
                continue
            shortcut_list.append(feature)
            batch_diff += _calculate_batch_diff(input_feature=feature)
            kept += 1
        shortcut_list.reverse()
        # Mean over the kept shortcuts (0 / 0 raises if none were supplied).
        batch_diff = batch_diff / kept

    return shortcut_list, mixed_fc, batch_diff