Example #1
    def discriminate(self, conv, reuse=False, pg=1, t=False, alpha_trans=0.01):

        #dis_as_v = []
        with tf.variable_scope("discriminator") as scope:

            if reuse:
                scope.reuse_variables()
            if t:
                conv_iden = avgpool2d(conv)
                #from RGB
                conv_iden = lrelu(conv2d(conv_iden, output_dim= self.get_nf(pg - 2), k_w=1, k_h=1, d_h=1, d_w=1,
                           name='dis_y_rgb_conv_{}'.format(conv_iden.shape[1])))
            # fromRGB
            conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_w=1, d_h=1, name='dis_y_rgb_conv_{}'.format(conv.shape[1])))
            for i in range(pg - 1):

                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1 - i), d_h=1, d_w=1,
                                    name='dis_n_conv_1_{}'.format(conv.shape[1])))
                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 2 - i), d_h=1, d_w=1,
                                                      name='dis_n_conv_2_{}'.format(conv.shape[1])))
                conv = avgpool2d(conv, 2)
                if i == 0 and t:
                    conv = alpha_trans * conv + (1 - alpha_trans) * conv_iden

            conv = MinibatchstateConcat(conv)
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=3, k_h=3, d_h=1, d_w=1, name='dis_n_conv_1_{}'.format(conv.shape[1])))
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=4, k_h=4, d_h=1, d_w=1, padding='VALID', name='dis_n_conv_2_{}'.format(conv.shape[1])))
            conv = tf.reshape(conv, [self.batch_size, -1])

            #for D
            output = fully_connect(conv, output_size=1, scope='dis_n_fully')

            return tf.nn.sigmoid(output), output
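Example #1 calls a MinibatchstateConcat helper that is not shown here. In PGGAN-style discriminators this layer appends a minibatch standard-deviation statistic as an extra feature map so the discriminator can judge sample diversity. The sketch below is only an assumption about what such a helper computes, written against the TF1-style API used in these examples; it is not the definition used by this code.

    import tensorflow as tf

    def minibatch_std_concat_sketch(x, eps=1e-8):
        # x: NHWC feature map; per-feature standard deviation across the batch
        mean = tf.reduce_mean(x, axis=0, keepdims=True)
        std = tf.sqrt(tf.reduce_mean(tf.square(x - mean), axis=0, keepdims=True) + eps)
        # collapse to one scalar statistic and broadcast it as a single extra channel
        avg_std = tf.reduce_mean(std, keepdims=True)
        shape = tf.shape(x)
        extra = tf.tile(avg_std, [shape[0], shape[1], shape[2], 1])
        return tf.concat([x, extra], axis=3)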
Example #2
    def discriminate(self, conv, t_text_embedding, reuse=False, pg=1, t=False, alpha_trans=0.01):

        #NOTE: discriminate from PGGAN does not use batch norm, add later?

        #dis_as_v = []
        with tf.variable_scope("discriminator") as scope:

            if reuse:
                scope.reuse_variables()
            if t:
                conv_iden = avgpool2d(conv)
                #from RGB
                conv_iden = lrelu(conv2d(conv_iden, output_dim= self.get_nf(pg - 2), k_w=1, k_h=1, d_h=1, d_w=1,
                           name='dis_y_rgb_conv_{}'.format(conv_iden.shape[1])))
            # fromRGB
            conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_w=1, d_h=1, name='dis_y_rgb_conv_{}'.format(conv.shape[1])))
            for i in range(pg - 1):

                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1 - i), d_h=1, d_w=1,
                                    name='dis_n_conv_1_{}'.format(conv.shape[1])))
                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 2 - i), d_h=1, d_w=1,
                                                      name='dis_n_conv_2_{}'.format(conv.shape[1])))
                conv = avgpool2d(conv, 2)
                if i == 0 and t:
                    conv = alpha_trans * conv + (1 - alpha_trans) * conv_iden

            conv = MinibatchstateConcat(conv)
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=3, k_h=3, d_h=1, d_w=1, name='dis_n_conv_1_{}'.format(conv.shape[1])))
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=4, k_h=4, d_h=1, d_w=1, padding='VALID', name='dis_n_conv_2_{}'.format(conv.shape[1])))


            #ADD TEXT EMBEDDING TO THE NETWORK
            reduced_text_embeddings = lrelu(linear(t_text_embedding, self.tdim, 'd_embedding'))
            reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings,1)
            reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings,2)
            #NOTE: output of prev layer is a 1x1 volume, so we don't tile by 4
            tiled_embeddings = tf.tile(reduced_text_embeddings, [1,1,1,1], name='tiled_embeddings') #last conv layer op should be 4x4
		
            conv_concat = tf.concat( [conv, tiled_embeddings], 3,  name='dis_conv_concat_{}'.format(conv.shape[1]))
            #NOTE: changed the output dims here as compared to text to image code
            conv_new = lrelu((conv2d(conv_concat, output_dim=self.get_nf(1), k_h=1,k_w=1,d_h=1,d_w=1, name = 'dis_conv_new_{}'.format(conv_concat.shape[1])))) #4



            conv_new = tf.reshape(conv_new, [self.batch_size, -1])

            #for D
            output = fully_connect(conv_new, output_size=1, scope='dis_n_fully')

            return tf.nn.sigmoid(output), output
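Example #2 conditions the discriminator on a text embedding by reducing it with a linear layer, expanding it to a 1x1 spatial map, tiling it over the last feature map, and concatenating along the channel axis before a 1x1 convolution. Because the preceding 4x4 VALID convolution collapses the map to 1x1, the tile multiples above are all ones. The sketch below shows the same pattern for a general spatial size; the function name and spatial_size argument are illustrative, not part of the original code.

    import tensorflow as tf

    def concat_text_embedding_sketch(conv, reduced_text_embeddings, spatial_size=1):
        # [batch, tdim] -> [batch, 1, 1, tdim]
        emb = tf.expand_dims(tf.expand_dims(reduced_text_embeddings, 1), 2)
        # replicate the text code at every spatial location of the feature map
        emb = tf.tile(emb, [1, spatial_size, spatial_size, 1])
        # concatenate along channels; a 1x1 conv then mixes image and text features
        return tf.concat([conv, emb], axis=3)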
Example #3
    def __call__(self,
                 x,
                 is_reuse=False,
                 is_train=True,
                 is_intermediate=False):
        with tf.variable_scope('generator') as scope:
            if is_reuse:
                scope.reuse_variables()

            unit_size = self.img_size[0] // (2**self.layer_n)
            unit_n = self.smallest_hidden_unit_n * (2**(self.layer_n - 1))
            batch_size = int(x.shape[0])
            if is_intermediate:
                intermediate_xs = list()

            with tf.variable_scope('pre'):
                x = linear(x, unit_size * unit_size * unit_n)
                x = tf.reshape(x, (batch_size, unit_size, unit_size, unit_n))
                if self.is_bn:
                    x = batch_norm(x, is_train)
                x = tf.nn.relu(x)
                if is_intermediate:
                    y = avgpool2d(x, unit_size, 1, 'VALID')
                    y = tf.reshape(y, (batch_size, -1))
                    intermediate_xs.append(y)

            for i in range(self.layer_n):
                with tf.variable_scope('layer{}'.format(i)):
                    if i == self.layer_n - 1:
                        unit_n = self.img_dim
                    else:
                        unit_n = self.smallest_hidden_unit_n * (2**(
                            self.layer_n - i - 2))
                    x_shape = x.get_shape().as_list()
                    if self.is_transpose:
                        x = deconv2d(x, [
                            x_shape[0], x_shape[1] * 2, x_shape[2] * 2, unit_n
                        ], self.k_size, 2, 'SAME')
                    else:
                        x = tf.image.resize_bilinear(
                            x, (x_shape[1] * 2, x_shape[2] * 2))
                        x = conv2d(x, unit_n, self.k_size, 1, 'SAME')
                    if i != self.layer_n - 1:
                        if self.is_bn:
                            x = batch_norm(x, is_train)
                        x = tf.nn.relu(x)
                        if is_intermediate:
                            y = tf.reshape(x, (batch_size, -1))
                            intermediate_xs.append(y)
            x = tf.nn.tanh(x)

            if is_intermediate:
                return x, intermediate_xs
            return x
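The generator in Example #3 starts from a small unit_size x unit_size map with many channels and repeatedly doubles the spatial size while halving the channel count, emitting img_dim channels at the last layer. The arithmetic below traces that schedule with illustrative hyperparameters (the values are assumptions, not taken from the original code).

    # Illustrative shape schedule for the generator above (assumed hyperparameters).
    img_size = (64, 64)
    layer_n = 4
    smallest_hidden_unit_n = 64
    img_dim = 3

    unit_size = img_size[0] // (2 ** layer_n)                # 4
    unit_n = smallest_hidden_unit_n * (2 ** (layer_n - 1))   # 512
    print('pre:', (unit_size, unit_size, unit_n))            # (4, 4, 512)
    for i in range(layer_n):
        unit_size *= 2
        unit_n = img_dim if i == layer_n - 1 else smallest_hidden_unit_n * (2 ** (layer_n - i - 2))
        print('layer{}:'.format(i), (unit_size, unit_size, unit_n))
    # layer0: (8, 8, 256)  layer1: (16, 16, 128)  layer2: (32, 32, 64)  layer3: (64, 64, 3)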
Example #4
    def discriminate(self, conv, reuse=False, pg=1, t=False, alpha_trans=0.01):

        #dis_as_v = []
        with tf.variable_scope("discriminator") as scope:

            if reuse:
                scope.reuse_variables()

            # If transition is True (i.e., when doubling the resolution)
            if t:
                conv_iden = avgpool2d(conv)
                #from RGB
                # If pg=2, get_nf returns 512
                # name is dis_y_rgb_conv_(current resolution / 2)
                conv_iden = lrelu(
                    conv2d(conv_iden,
                           output_dim=self.get_nf(pg - 2),
                           k_w=1,
                           k_h=1,
                           d_h=1,
                           d_w=1,
                           name='dis_y_rgb_conv_{}'.format(
                               conv_iden.shape[1])))
            # fromRGB
            # When pg=1, get_nf(pg-1) returns 512
            # When pg=2, get_nf(pg-1) returns 512
            # When pg=3, get_nf(pg-1) returns 256
            # name is dis_y_rgb_conv_(current resolution)
            conv = lrelu(
                conv2d(conv,
                       output_dim=self.get_nf(pg - 1),
                       k_w=1,
                       k_h=1,
                       d_w=1,
                       d_h=1,
                       name='dis_y_rgb_conv_{}'.format(conv.shape[1])))
            # When pg=1, the for loop does not run
            # When pg=2, i takes the value 0
            # When pg=3, i takes the values 0 and 1
            for i in range(pg - 1):
                # Default kernel size is 3x3
                # When pg=2 and i=0, get_nf(pg-1-i) returns 512
                conv = lrelu(
                    conv2d(conv,
                           output_dim=self.get_nf(pg - 1 - i),
                           d_h=1,
                           d_w=1,
                           name='dis_n_conv_1_{}'.format(conv.shape[1])))
                # When pg=2 and i=0, get_nf(pg-2-i) returns 512
                conv = lrelu(
                    conv2d(conv,
                           output_dim=self.get_nf(pg - 2 - i),
                           d_h=1,
                           d_w=1,
                           name='dis_n_conv_2_{}'.format(conv.shape[1])))
                conv = avgpool2d(conv, 2)
                # If this is the first block and transition is True, blend with the skip path
                if i == 0 and t:
                    conv = alpha_trans * conv + (1 - alpha_trans) * conv_iden

            conv = MinibatchstateConcat(conv)
            # Channel dimension is always 512
            conv = lrelu(
                conv2d(conv,
                       output_dim=self.get_nf(1),
                       k_w=3,
                       k_h=3,
                       d_h=1,
                       d_w=1,
                       name='dis_n_conv_1_{}'.format(conv.shape[1])))
            # Channel dimension is always 512; VALID padding shrinks the feature map's width and height
            conv = lrelu(
                conv2d(conv,
                       output_dim=self.get_nf(1),
                       k_w=4,
                       k_h=4,
                       d_h=1,
                       d_w=1,
                       padding='VALID',
                       name='dis_n_conv_2_{}'.format(conv.shape[1])))
            # Flatten to one dimension
            conv = tf.reshape(conv, [self.batch_size, -1])

            #for D
            output = fully_connect(conv, output_size=1, scope='dis_n_fully')

            return tf.nn.sigmoid(output), output
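The comments in Example #4 note the channel counts that get_nf returns at each stage (512 for the first two stages, 256 for the third, and so on). The helper itself is not shown; one plausible definition that reproduces those numbers, capping the width at 512, is sketched below as an assumption.

    def get_nf_sketch(stage):
        # 1024 // 2**stage, capped at 512: stage 0 -> 512, 1 -> 512, 2 -> 256, 3 -> 128, ...
        return min(1024 // (2 ** stage), 512)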
Example #5
    def discriminate(self,
                     input_image,
                     reuse=False,
                     model_progressive_depth=1,
                     transition=False,
                     alpha_transition=0.01,
                     input_classes=None):

        with tf.variable_scope("discriminator") as scope:

            if reuse:
                scope.reuse_variables()

            if transition:
                # from RGB, low resolution
                transition_conv = avgpool2d(input_image)
                transition_conv = lrelu(
                    conv2d(transition_conv,
                           output_dim=self.get_filter_num(
                               model_progressive_depth - 2),
                           k_w=1,
                           k_h=1,
                           d_h=1,
                           d_w=1,
                           name='dis_y_rgb_conv_{}'.format(
                               transition_conv.shape[1])))

            convs = []

            # from RGB, high resolution
            convs += [
                lrelu(
                    conv2d(input_image,
                           output_dim=self.get_filter_num(
                               model_progressive_depth - 1),
                           k_w=1,
                           k_h=1,
                           d_w=1,
                           d_h=1,
                           name='dis_y_rgb_conv_{}'.format(
                               input_image.shape[1])))
            ]

            for i in range(model_progressive_depth - 1):

                convs += [
                    lrelu(
                        conv2d(convs[-1],
                               output_dim=self.get_filter_num(
                                   model_progressive_depth - 1 - i),
                               d_h=1,
                               d_w=1,
                               name='dis_n_conv_1_{}'.format(
                                   convs[-1].shape[1])))
                ]

                convs += [
                    lrelu(
                        conv2d(convs[-1],
                               output_dim=self.get_filter_num(
                                   model_progressive_depth - 2 - i),
                               d_h=1,
                               d_w=1,
                               name='dis_n_conv_2_{}'.format(
                                   convs[-1].shape[1])))
                ]
                convs[-1] = avgpool2d(convs[-1], 2)

                if i == 0 and transition:
                    convs[-1] = alpha_transition * convs[-1] + (
                        1 - alpha_transition) * transition_conv

            convs += [minibatch_state_concat(convs[-1])]
            convs[-1] = lrelu(
                conv2d(convs[-1],
                       output_dim=self.get_filter_num(1),
                       k_w=3,
                       k_h=3,
                       d_h=1,
                       d_w=1,
                       name='dis_n_conv_1_{}'.format(convs[-1].shape[1])))

            output = tf.reshape(convs[-1], [self.batch_size, -1])
            discriminate_output = fully_connect(output,
                                                output_size=1,
                                                scope='dis_n_fully')

            return tf.nn.sigmoid(discriminate_output), discriminate_output
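All five examples take an alpha_trans / alpha_transition argument that blends the new high-resolution path with the downsampled fromRGB skip path (conv = alpha * conv + (1 - alpha) * conv_iden). The training loop that drives this value is not part of these examples; a common choice, sketched below as an assumption, is a linear fade-in from 0 to 1 over the transition phase of each resolution.

    def fade_in_alpha(current_step, total_fade_steps):
        # 0.0 at the start of a resolution transition, 1.0 once the new layers are fully active
        return min(1.0, float(current_step) / float(total_fade_steps))

    # e.g. alpha = fade_in_alpha(step, 10000), then passed as alpha_trans / alpha_transition
    # to discriminate(...) while t / transition is True.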