示例#1
0
    def decoder(self, z, name='decoder', is_reuse=False):
        """Build the decoder graph that maps ``z`` to an image-shaped tensor.

        Two hidden layers of width ``self.n_hidden`` (tanh, then ELU), each
        followed by dropout, feed a sigmoid output layer of width
        ``self.output_dim``.

        Args:
            z: Input tensor fed to the first linear layer.
            name: Variable scope under which every layer is created.
            is_reuse: When exactly ``True``, the scope's variables are reused.

        Returns:
            The sigmoid activations reshaped to ``[-1, *self.image_size]``.
        """
        with tf.variable_scope(name) as scope:
            if is_reuse is True:
                scope.reuse_variables()
            tf_utils.print_activations(z)

            # Hidden layers share one recipe: linear -> activation -> dropout.
            hidden_activations = ((tf_utils.tanh, 'tanh'), (tf_utils.elu, 'elu'))
            net = z
            for idx, (activation, tag) in enumerate(hidden_activations):
                net = tf_utils.linear(net,
                                      self.n_hidden,
                                      name='h{}_linear'.format(idx))
                net = activation(net, name='h{}_{}'.format(idx, tag))
                net = tf.nn.dropout(net,
                                    keep_prob=self.keep_prob_tfph,
                                    name='h{}_drop'.format(idx))
                tf_utils.print_activations(net)

            # Output layer (no dropout): linear -> sigmoid.
            out_linear = tf_utils.linear(net, self.output_dim, name='h2_linear')
            out_sigmoid = tf_utils.sigmoid(out_linear, name='h2_sigmoid')
            tf_utils.print_activations(out_sigmoid)

            output = tf.reshape(out_sigmoid, [-1, *self.image_size])
            tf_utils.print_activations(output)

        return output
示例#2
0
    def bottleneck_block(self, inputs, filters, train_mode,
                         projection_shortcut, strides, name):
        """Build a residual block of two 3x3 convolutions, each preceded by ReLU.

        Args:
            inputs: Input tensor of the block.
            filters: Filter count for both convolutions and for the projection.
            train_mode: Accepted for interface compatibility; not read here.
            projection_shortcut: Only tested against ``None``. When it is not
                ``None`` the shortcut comes from ``self.projection_shortcut``
                applied to the activated input; otherwise the raw input is the
                shortcut.
            strides: Stride of the first convolution and of the projection.
            name: Variable scope name, also the prefix of the output op name.

        Returns:
            ``tf.identity`` of the residual branch plus the shortcut.
        """
        with tf.compat.v1.variable_scope(name):
            activated = tf_utils.relu(inputs, name='relu_0', logger=None)

            # The projection shortcut should come after the first ReLU since
            # it performs a 1x1 convolution.
            if projection_shortcut is None:
                shortcut = inputs
            else:
                shortcut = self.projection_shortcut(inputs=activated,
                                                    filters_out=filters,
                                                    strides=strides,
                                                    name='conv_projection')

            # Residual branch: conv(strides) -> ReLU -> conv(stride 1).
            residual = self.conv2d_fixed_padding(inputs=activated,
                                                 filters=filters,
                                                 kernel_size=3,
                                                 strides=strides,
                                                 name='conv_0')
            residual = tf_utils.relu(residual, name='relu_1', logger=None)
            residual = self.conv2d_fixed_padding(inputs=residual,
                                                 filters=filters,
                                                 kernel_size=3,
                                                 strides=1,
                                                 name='conv_1')

            output = tf.identity(residual + shortcut, name=(name + '_output'))
            tf_utils.print_activations(output, logger=None)

            return output
示例#3
0
    def model_g(self, x):
        """Build the convolutional stack and return its un-activated output.

        Four stride-2 4x4 convolutions are followed by a stride-1 4x4
        convolution with a single output channel. After the graph is built,
        ``self.reuse`` is set to ``True`` and ``self.variables`` is refreshed
        with the trainable variables found under ``self.name``.

        Args:
            x: Input tensor (the original notes describe it as (N, H, W, C)).

        Returns:
            Output of the last convolution wrapped in ``tf.identity`` and
            named ``output_without_sigmoid``.
        """
        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            # Settings shared by every down-sampling convolution below.
            down = dict(k_h=4, k_w=4, d_h=2, d_w=2, padding='SAME')

            # Layer 1: convolution + leaky ReLU, ndf output channels.
            net = tf_utils.conv2d(x, self.ndf, name='conv1_conv', **down)
            net = tf_utils.lrelu(net, name='conv1_lrelu', is_print=True)

            # Layers 2-3: conv + norm + leaky ReLU with 2*ndf and 4*ndf channels.
            for layer_no, width in ((2, 2), (3, 4)):
                net = tf_utils.conv_norm_lrelu(net, width * self.ndf,
                                               name='conv{}_conv'.format(layer_no),
                                               ops=self._ops, **down)

            # Layer 4: 8*ndf channels.
            # NOTE(review): unlike layers 2-3 this is a bare conv2d that still
            # receives `ops`; confirm whether conv_norm_lrelu was intended.
            net = tf_utils.conv2d(net, 8 * self.ndf, name='conv4_conv',
                                  ops=self._ops, **down)

            # Layer 5: stride-1 convolution down to a single channel.
            net = tf_utils.conv2d(net, 1, k_h=4, k_w=4, d_h=1, d_w=1, padding='SAME',
                                  name='conv5_conv', is_print=True)

            output = tf.identity(net, name='output_without_sigmoid')

            # Later calls must reuse the variables created above.
            self.reuse = True
            self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

            return output
示例#4
0
    def forward_network(self, input_img, reuse=False):
        """Build the residual network and return its attribute logits.

        The graph is a 7x7 stride-2 convolution, a 3x3 stride-2 max pool, four
        block layers, a ReLU, a flatten, and three linear layers
        (512 -> 256 -> ``self.num_attribute``).

        Args:
            input_img: Input image tensor.
            reuse: Forwarded to the variable scope named ``self.name``.

        Returns:
            Output of the final linear layer (no activation applied).
        """
        # (filters, strides) of each block layer; the block count for stage i
        # is read from self.layers[i].
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))

        with tf.compat.v1.variable_scope(self.name, reuse=reuse):
            tf_utils.print_activations(input_img, logger=None)

            # Stem: convolution followed by max pooling.
            net = self.conv2d_fixed_padding(inputs=input_img, filters=64, kernel_size=7, strides=2, name='conv1')
            net = tf_utils.max_pool(net, name='3x3_maxpool', ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                                    logger=None)

            # Residual stages; train_mode is always passed as False here.
            for stage, (filters, strides) in enumerate(stage_specs):
                net = self.block_layer(inputs=net, filters=filters, block_fn=self.bottleneck_block,
                                       blocks=self.layers[stage], strides=strides, train_mode=False,
                                       name='block_layer{}'.format(stage + 1))

            net = tf_utils.relu(net, name='before_flatten_relu', logger=None)

            # Head: flatten, then two ReLU-activated linear layers.
            net = tf_utils.flatten(net, name='flatten', logger=None)
            for units, layer_name in ((512, 'FC1'), (256, 'FC2')):
                net = tf_utils.linear(net, units, name=layer_name)
                net = tf_utils.relu(net, name=layer_name + '_relu', logger=None)

            logits = tf_utils.linear(net, self.num_attribute, name='Out')

            return logits
示例#5
0
    def basicDiscriminator(self, data, name='d_', is_reuse=False):
        """Build a three-stage convolutional discriminator.

        Each stage is a 5x5 convolution with ``self.dis_c[i]`` output channels
        followed by a leaky ReLU; the result is flattened and projected to a
        single unit.

        Args:
            data: Input tensor (the original notes describe it as (N, 32, 32, 1)).
            name: Variable scope under which all layers are created.
            is_reuse: When exactly ``True``, the scope's variables are reused.

        Returns:
            Tuple ``(probability, logits)`` where ``probability`` is the
            sigmoid of ``logits``.
        """
        with tf.variable_scope(name) as scope:
            if is_reuse is True:
                scope.reuse_variables()
            tf_utils.print_activations(data)

            # Stages h0..h2: conv2d (5x5 kernel) + leaky ReLU.
            # Per the original notes each stage halves the spatial size:
            # (N, 32, 32, 1) -> (N, 16, 16, 64) -> (N, 8, 8, 128) -> (N, 4, 4, 256).
            net = data
            for idx in range(3):
                net = tf_utils.conv2d(net,
                                      self.dis_c[idx],
                                      k_h=5,
                                      k_w=5,
                                      name='h{}_conv2d'.format(idx))
                net = tf_utils.lrelu(net, name='h{}_lrelu'.format(idx))

            # Flatten the feature map and map it to a single logit.
            flat = flatten(net)
            logits = tf_utils.linear(flat, 1, name='h3_linear')

            return tf.nn.sigmoid(logits), logits
示例#6
0
    def basicGenerator(self, data, name='g_'):
        """Build the generator: linear projection followed by three deconvolutions.

        Batch normalisation is inserted after the projection and after the
        first two deconvolutions only when ``self.flags.dataset`` equals
        ``'cifar10'``.

        Args:
            data: Input tensor; it is flattened before the first layer.
            name: Variable scope under which all layers are created.

        Returns:
            ``tf_utils.tanh`` of the last deconvolution, which has
            ``self.image_size[2]`` output channels.
        """
        with tf.variable_scope(name):
            flat_input = flatten(data)
            tf_utils.print_activations(flat_input)

            # Project to gen_c[0] units, later viewed as a 4x4 feature map.
            net = tf_utils.linear(flat_input,
                                  self.gen_c[0],
                                  name='h0_linear')
            use_batch_norm = self.flags.dataset == 'cifar10'
            depth = int(self.gen_c[0] / (4 * 4))
            if use_batch_norm:
                # Normalise on the 4-D view rather than on the flat vector.
                net = tf.reshape(net, [tf.shape(net)[0], 4, 4, depth])
                net = tf_utils.norm(net,
                                    _type='batch',
                                    _ops=self.gen_train_ops,
                                    name='h0_norm')
            net = tf.nn.relu(net, name='h0_relu')
            net = tf.reshape(net, [tf.shape(net)[0], 4, 4, depth])

            # Stages h1 and h2: deconv (5x5 kernel) -> optional batch norm -> ReLU.
            # Per the original notes: (N, 4, 4, 256) -> (N, 8, 8, 128) -> (N, 16, 16, 64).
            for stage in (1, 2):
                net = tf_utils.deconv2d(net,
                                        self.gen_c[stage],
                                        k_h=5,
                                        k_w=5,
                                        name='h{}_deconv2d'.format(stage))
                if use_batch_norm:
                    net = tf_utils.norm(net,
                                        _type='batch',
                                        _ops=self.gen_train_ops,
                                        name='h{}_norm'.format(stage))
                net = tf.nn.relu(net, name='h{}_relu'.format(stage))

            # Final deconvolution down to the image channel count.
            net = tf_utils.deconv2d(net,
                                    self.image_size[2],
                                    k_h=5,
                                    k_w=5,
                                    name='h3_deconv2d')

            return tf_utils.tanh(net)
示例#7
0
    def __call__(self, x, mode=1):
        """Run the frozen 5-block convolutional stack and return feature taps.

        The input is concatenated with itself three times along the last axis
        before the first convolution. All conv layers are created with
        ``trainable=False``. ``self.reuse`` is set to ``True`` once the graph
        has been built.

        Args:
            x: Input tensor.
            mode: Number of blocks (1-5) whose last ReLU output is returned.

        Returns:
            A list with the last ReLU output of blocks ``1..mode``
            (relu1_2, relu2_2, relu3_3, relu4_3, relu5_3).

        Raises:
            NotImplementedError: If ``mode`` is not equal to one of 1..5.
        """
        # Number of conv layers in each of the five blocks.
        convs_per_block = (2, 2, 3, 3, 3)

        with tf.variable_scope(self.name, reuse=self.reuse):
            x = tf.concat([x, x, x], axis=-1, name='concat')
            tf_utils.print_activations(x)

            taps = []
            net = x
            for block_no, n_convs in enumerate(convs_per_block, start=1):
                for layer_no in range(1, n_convs + 1):
                    net = self.conv_layer(net, 'conv{}_{}'.format(block_no, layer_no), trainable=False)
                # The block's last ReLU is what callers can ask for.
                taps.append(net)

                # Every block except the last one is followed by 2x2 max pooling.
                if block_no < len(convs_per_block):
                    net = tf_utils.max_pool_2x2(net, name='max_pool_{}'.format(block_no))
                    tf_utils.print_activations(net)

            # set reuse=True for next call
            self.reuse = True

            # The whole graph is built before `mode` is validated.
            for depth in range(1, len(taps) + 1):
                if mode == depth:
                    return taps[:depth]
            raise NotImplementedError
示例#8
0
    def __call__(self, x):
        """Build the encoder / residual / decoder generator graph.

        The graph is a reflect-padded 7x7 convolution, two stride-2 3x3
        convolutions, 6 or 9 residual blocks (6 when both
        ``self.image_size[0]`` and ``self.image_size[1]`` are at most 128),
        two deconvolutions, and a reflect-padded 7x7 convolution with
        ``self.image_size[2]`` channels followed by tanh. Every intermediate
        stage uses instance normalisation and ReLU.

        After building, ``self.reuse`` is set to ``True`` and
        ``self.variables`` is refreshed with the trainable variables under
        ``self.name``.

        Args:
            x: Input tensor (the original notes describe it as (N, H, W, C)).

        Returns:
            The tanh-activated output tensor.
        """
        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            # Stage 1: reflect padding + 7x7 VALID convolution, ngf channels.
            net = tf_utils.padding2d(x, p_h=3, p_w=3, pad_type='REFLECT', name='conv1_padding')
            net = tf_utils.conv2d(net, self.ngf, k_h=7, k_w=7, d_h=1, d_w=1, padding='VALID',
                                  name='conv1_conv')
            net = tf_utils.norm(net, _type='instance', _ops=self._ops, name='conv1_norm')
            net = tf_utils.relu(net, name='conv1_relu', is_print=True)

            # Stages 2-3: stride-2 3x3 convolutions with 2*ngf and 4*ngf channels.
            for stage, factor in ((2, 2), (3, 4)):
                net = tf_utils.conv2d(net, factor * self.ngf, k_h=3, k_w=3, d_h=2, d_w=2, padding='SAME',
                                      name='conv{}_conv'.format(stage))
                net = tf_utils.norm(net, _type='instance', _ops=self._ops,
                                    name='conv{}_norm'.format(stage))
                net = tf_utils.relu(net, name='conv{}_relu'.format(stage), is_print=True)

            # Residual trunk: 6 blocks for images up to 128x128, 9 otherwise.
            small_image = (self.image_size[0] <= 128) and (self.image_size[1] <= 128)
            net = tf_utils.n_res_blocks(net, num_blocks=6 if small_image else 9, is_print=True)

            # Stages 4-5: deconvolutions with 2*ngf and ngf channels.
            for stage, channels in ((4, 2 * self.ngf), (5, self.ngf)):
                net = tf_utils.deconv2d(net, channels, name='conv{}_deconv2d'.format(stage))
                net = tf_utils.norm(net, _type='instance', _ops=self._ops,
                                    name='conv{}_norm'.format(stage))
                net = tf_utils.relu(net, name='conv{}_relu'.format(stage), is_print=True)

            # Output: reflect padding + 7x7 VALID convolution + tanh.
            net = tf_utils.padding2d(net, p_h=3, p_w=3, pad_type='REFLECT', name='output_padding')
            net = tf_utils.conv2d(net, self.image_size[2], k_h=7, k_w=7, d_h=1, d_w=1,
                                  padding='VALID', name='output_conv')
            output = tf_utils.tanh(net, name='output_tanh', is_print=True)

            # Later calls must reuse the variables created above.
            self.reuse = True
            self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

            return output
示例#9
0
    def resnetDiscriminator(self, data, name='d_', is_reuse=False):
        """Build a residual discriminator: one convolution and four down blocks.

        Args:
            data: Input tensor.
            name: Variable scope under which all layers are created.
            is_reuse: When exactly ``True``, the scope's variables are reused.

        Returns:
            Tuple ``(probability, logits)`` where ``probability`` is the
            sigmoid of ``logits``.
        """
        with tf.variable_scope(name) as scope:
            if is_reuse is True:
                scope.reuse_variables()
            tf_utils.print_activations(data)

            # Stem: 3x3 stride-1 convolution with dis_c[0] channels.
            # Original note: output is (N, 64, 64, 64).
            net = tf_utils.conv2d(data,
                                  output_dim=self.dis_c[0],
                                  k_h=3,
                                  k_w=3,
                                  d_h=1,
                                  d_w=1,
                                  name='conv_0')

            # Four layer-normalised residual blocks with resample='down'.
            # Original notes: (N, 32, 32, 128) -> (N, 16, 16, 256)
            # -> (N, 8, 8, 512) -> (N, 4, 4, 512).
            for block_no in range(1, 5):
                net = tf_utils.res_block_v2(net,
                                            self.dis_c[block_no],
                                            filter_size=3,
                                            _ops=self.dis_train_ops,
                                            norm_='layer',
                                            resample='down',
                                            name='res_block_{}'.format(block_no))

            # Flatten and project to a single logit.
            flat = flatten(net)
            logits = tf_utils.linear(flat, 1, name='output')

            return tf.nn.sigmoid(logits), logits
示例#10
0
    def __call__(self, x):
        """Build the discriminator graph and return its patch-wise output.

        Four default-sized convolutions (``self.dis_c[0..3]`` channels) are
        each followed by a leaky ReLU, with batch normalisation on all but the
        first. A final 3x3 stride-1 convolution with ``self.dis_c[4]``
        channels produces the output.

        After building, ``self.reuse`` is set to ``True`` and
        ``self.variables`` is refreshed with the trainable variables under
        ``self.name``.

        Args:
            x: Input tensor.

        Returns:
            Output of the last convolution (no activation applied).
        """
        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            # h0: convolution + leaky ReLU, no normalisation.
            # Original note: spatial size 200 -> 100.
            net = tf_utils.conv2d(x, self.dis_c[0], name='h0_conv2d')
            net = tf_utils.lrelu(net, name='h0_lrelu')

            # h1..h3: convolution + batch norm + leaky ReLU.
            # Original notes: 100 -> 50 -> 25 -> 13.
            for idx in range(1, 4):
                net = tf_utils.conv2d(net,
                                      self.dis_c[idx],
                                      name='h{}_conv2d'.format(idx))
                net = tf_utils.batch_norm(net,
                                          name='h{}_batchnorm'.format(idx),
                                          _ops=self._ops)
                net = tf_utils.lrelu(net, name='h{}_lrelu'.format(idx))

            # Patch GAN head: stride-1 convolution (original note: 13 -> 13).
            output = tf_utils.conv2d(net,
                                     self.dis_c[4],
                                     k_h=3,
                                     k_w=3,
                                     d_h=1,
                                     d_w=1,
                                     name='output_conv2d')

            # Later calls must reuse the variables created above.
            self.reuse = True
            self.variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

            return output
示例#11
0
    def __call__(self, x):
        """Build the patch discriminator and return probabilities and logits.

        The graph is a stride-2 4x4 convolution with ``self.ndf`` channels and
        a leaky ReLU, then one stride-2 4x4 convolution + normalisation +
        leaky ReLU per entry of ``self.hidden_dims[1:]``, then a stride-1 4x4
        convolution with a single output channel.

        After building, ``self.reuse`` is set to ``True`` and
        ``self.variables`` is refreshed with the trainable variables under
        ``self.name``.

        Args:
            x: Input tensor (the original notes describe it as (N, H, W, 3)).

        Returns:
            Tuple ``(probability, logits)`` where ``probability`` is
            ``tf_utils.sigmoid`` of the final convolution output.
        """
        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            # conv0: first down-sampling convolution, no normalisation.
            output = tf_utils.conv2d(x,
                                     self.ndf,
                                     k_h=4,
                                     k_w=4,
                                     d_h=2,
                                     d_w=2,
                                     padding='SAME',
                                     name='conv0_conv2d')
            output = tf_utils.lrelu(output, name='conv0_lrelu', is_print=True)

            # conv1..convK: one down-sampling stage per remaining hidden dim.
            for layer_no, hidden_dim in enumerate(self.hidden_dims[1:], start=1):
                output = tf_utils.conv2d(output,
                                         hidden_dim,
                                         k_h=4,
                                         k_w=4,
                                         d_h=2,
                                         d_w=2,
                                         padding='SAME',
                                         name='conv{}_conv2d'.format(layer_no))
                output = tf_utils.norm(output,
                                       _type=self.norm,
                                       _ops=self._ops,
                                       name='conv{}_norm'.format(layer_no))
                output = tf_utils.lrelu(output,
                                        name='conv{}_lrelu'.format(layer_no),
                                        is_print=True)

            # Final stride-1 convolution to one channel. The scope used to be
            # hard-coded as 'conv4_conv2d', which clashes with the loop above
            # as soon as hidden_dims has five or more entries. Index 4 is kept
            # for shorter configurations so existing variable names (and
            # checkpoints) are unchanged.
            final_no = max(4, len(self.hidden_dims))
            output = tf_utils.conv2d(output,
                                     1,
                                     k_h=4,
                                     k_w=4,
                                     d_h=1,
                                     d_w=1,
                                     padding='SAME',
                                     name='conv{}_conv2d'.format(final_no))

            # Later calls must reuse the variables created above.
            self.reuse = True
            self.variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

            return tf_utils.sigmoid(output), output
示例#12
0
    def network(self, inputs, name=None):
        """Build the masked-convolution network and return probabilities and logits.

        The graph is a 7x7 type-"A" masked convolution, then
        ``self.recurrent_length`` main layers chosen by ``self.flags.model``,
        then ``self.out_recurrent_length`` 1x1 type-"B" masked convolutions
        with ReLU, and a final 1x1 type-"B" masked convolution with
        ``self.img_size[2]`` channels.

        Args:
            inputs: Input tensor.
            name: Variable scope under which all layers are created.

        Returns:
            Tuple ``(probability, logits)`` where ``probability`` is
            ``tf_utils.sigmoid`` of the final convolution output.

        Raises:
            NotImplementedError: If ``self.flags.model`` is anything other
                than ``'pixelcnn'`` or ``'diagonal_bilstm'`` (this includes
                ``'row_lstm'``).
        """
        with tf.variable_scope(name):
            tf_utils.print_activations(inputs)

            # input of main recurrent layers
            output = tf_utils.conv2d_mask(inputs,
                                          2 * self.hidden_dims, [7, 7],
                                          mask_type="A",
                                          name='inputConv1')

            # main recurrent layers
            if self.flags.model == 'pixelcnn':
                for idx in range(self.recurrent_length):
                    output = tf_utils.conv2d_mask(
                        output,
                        self.hidden_dims, [3, 3],
                        mask_type="B",
                        name='mainConv{}'.format(idx + 2))
                    output = tf_utils.relu(output,
                                           name='mainRelu{}'.format(idx + 2))
            elif self.flags.model == 'diagonal_bilstm':
                for idx in range(self.recurrent_length):
                    output = self.diagonal_bilstm(output,
                                                  name='BiLSTM{}'.format(idx +
                                                                         2))
            else:
                # 'row_lstm' and any unknown model are both unsupported.
                raise NotImplementedError

            # output recurrent layers
            for idx in range(self.out_recurrent_length):
                output = tf_utils.conv2d_mask(output,
                                              self.hidden_dims, [1, 1],
                                              mask_type="B",
                                              name='outputConv{}'.format(idx +
                                                                         1))
                output = tf_utils.relu(output,
                                       name='outputRelu{}'.format(idx + 1))

            # TODO: for color images, implement a 256-way softmax for each RGB channel here
            # The final layer used to be hard-coded as 'outputConv3', which
            # clashes with the loop above once out_recurrent_length >= 3.
            # Index 3 is kept for shorter configurations so existing variable
            # names (and checkpoints) are unchanged.
            final_no = max(3, self.out_recurrent_length + 1)
            output = tf_utils.conv2d_mask(output,
                                          self.img_size[2], [1, 1],
                                          mask_type="B",
                                          name='outputConv{}'.format(final_no))

            return tf_utils.sigmoid(output), output
    def __init__(self, input_dim, output_dim=1, optimizer=None, use_dropout=True, lr=0.001, random_seed=123,
                 is_train=True, log_dir=None, name=None):
        """Build a single linear layer model with its loss, optimizer and accuracy.

        Args:
            input_dim: Width of the ``X`` placeholder.
            output_dim: Width of the ``y`` placeholder and of the linear layer.
            optimizer: Forwarded to ``optimizer_fn``.
            use_dropout: When truthy, dropout is applied to the input.
            lr: Learning rate forwarded to ``optimizer_fn``.
            random_seed: Seed forwarded to the dropout layer.
            is_train: When falsy, no logger is handed to the ``tf_utils`` helpers.
            log_dir: Directory forwarded to ``utils.init_logger``.
            name: Variable scope name, also used for the logger and optimizer.
        """
        self.name = name
        self.is_train = is_train
        self.log_dir = log_dir
        self.cur_lr = None
        self.logger, self.file_handler, self.stream_handler = utils.init_logger(
            log_dir=self.log_dir, name=self.name, is_train=self.is_train)

        # Graph-construction messages are only logged while training.
        build_logger = self.logger if self.is_train else None

        with tf.variable_scope(self.name):
            # Placeholders for inputs
            self.X = tf.placeholder(dtype=tf.float32, shape=[None, input_dim], name='X')
            self.y = tf.placeholder(dtype=tf.float32, shape=[None, output_dim], name='y')
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
            tf_utils.print_activations(self.X, logger=build_logger)

            # Placeholders for TensorBoard
            self.train_acc = tf.placeholder(tf.float32, name='train_acc')
            self.val_acc = tf.placeholder(tf.float32, name='val_acc')

            features = self.X
            if use_dropout:
                features = tf_utils.dropout(x=features,
                                            keep_prob=self.keep_prob,
                                            seed=random_seed,
                                            name='dropout',
                                            logger=build_logger)

            # Network, loss, and optimizer
            self.y_pred = tf_utils.linear(features, output_size=output_dim)
            tf_utils.print_activations(self.y_pred, logger=build_logger)
            # tf.nn.l2_loss already reduces to a scalar (sum of squares / 2),
            # so the surrounding reduce_mean leaves the value unchanged.
            residual = self.y_pred - self.y
            self.loss = tf.math.reduce_mean(tf.nn.l2_loss(residual))
            self.train_op, self.cur_lr = optimizer_fn(optimizer, lr=lr, loss=self.loss, name=self.name)

            # Accuracy: percentage of rounded predictions whose int32 cast
            # equals the int32 cast of the labels.
            self.y_pred_round = tf.math.round(x=self.y_pred, name='rounded_pred')
            pred_int = tf.cast(x=self.y_pred_round, dtype=tf.int32)
            label_int = tf.cast(x=self.y, dtype=tf.int32)
            is_correct = tf.equal(pred_int, label_int)
            self.accuracy = tf.reduce_mean(tf.cast(x=is_correct, dtype=tf.float32)) * 100.

        self._tensorboard()
        # Evaluated after _tensorboard(), matching the original statement order.
        tf_utils.show_all_variables(logger=self.logger if self.is_train else None)
示例#14
0
    def __call__(self, x):
        """Run the frozen 5-block convolutional stack and return its last feature map.

        The input is concatenated with itself three times along the last axis
        before the first convolution. All conv layers are created with
        ``trainable=False``. ``self.reuse`` is set to ``True`` once the graph
        has been built.

        Args:
            x: Input tensor.

        Returns:
            Output of the third conv layer of block 5 (``conv5_3``).
        """
        # Number of conv layers in each of the five blocks.
        convs_per_block = (2, 2, 3, 3, 3)

        with tf.variable_scope(self.name, reuse=self.reuse):
            x = tf.concat([x, x, x], axis=-1, name='concat')
            tf_utils.print_activations(x)

            net = x
            for block_no, n_convs in enumerate(convs_per_block, start=1):
                for layer_no in range(1, n_convs + 1):
                    net = self.conv_layer(net, 'conv{}_{}'.format(block_no, layer_no), trainable=False)

                # Every block except the last one is followed by 2x2 max pooling.
                if block_no < len(convs_per_block):
                    net = tf_utils.max_pool_2x2(net, name='max_pool_{}'.format(block_no))
                    tf_utils.print_activations(net)

            # set reuse=True for next call
            self.reuse = True

            return net
示例#15
0
    def conv_layer(self, bottom, name, trainable=False):
        """Create a stride-1 SAME convolution + bias + ReLU from stored weights.

        The kernel and bias values come from ``self.get_conv_weight(name)``
        and ``self.get_bias(name)`` and initialise the variables ``"W"`` and
        ``"b"`` inside a variable scope called ``name``.

        Args:
            bottom: Input tensor of the convolution.
            name: Layer name; used for the weight lookup and the variable scope.
            trainable: Forwarded to both ``tf.get_variable`` calls.

        Returns:
            The ReLU-activated output tensor.
        """
        with tf.variable_scope(name):
            kernel_values = self.get_conv_weight(name)
            bias_values = self.get_bias(name)

            kernel = tf.get_variable("W",
                                     shape=kernel_values.shape,
                                     initializer=tf.constant_initializer(kernel_values),
                                     trainable=trainable)
            offset = tf.get_variable("b",
                                     shape=bias_values.shape,
                                     initializer=tf.constant_initializer(bias_values),
                                     trainable=trainable)

            # conv2d -> bias_add -> relu, evaluated innermost first.
            activated = tf.nn.relu(
                tf.nn.bias_add(
                    tf.nn.conv2d(bottom, kernel, [1, 1, 1, 1], padding='SAME'),
                    offset))
            tf_utils.print_activations(activated)

        return activated
    def __init__(self, input_dim, output_dim=None, optimizer=None, use_dropout=True, lr=0.001,
                 weight_decay=1e-4, random_seed=123, is_train=True, log_dir=None, name=None):
        """Build the fully connected classifier graph, its loss, optimizer and accuracy ops.

        Args:
            input_dim: size of the second axis of the ``X`` placeholder.
            output_dim: sequence of layer widths. Every entry except the last
                becomes a hidden linear layer; the last entry is the width of
                the ``y`` placeholder and of the final linear layer. ``None``
                (the default) means ``[1000, 1000, 10]``.
            optimizer: forwarded to ``optimizer_fn``.
            use_dropout: when true, ``tf_utils.dropout`` is applied after each
                hidden linear layer, before its ReLU.
            lr: forwarded to ``optimizer_fn``.
            weight_decay: multiplier on the summed L2 loss of the variables in
                the ``TRAINABLE_VARIABLES`` collection.
            random_seed: seed forwarded to every dropout layer.
            is_train: passed to ``utils.init_logger``; layer helpers receive
                the logger only when this is truthy.
            log_dir: passed to ``utils.init_logger``.
            name: variable-scope name, also passed to ``utils.init_logger`` and
                ``optimizer_fn``.
                NOTE(review): the ``None`` default is handed straight to
                ``tf.variable_scope`` - confirm callers always supply a name.
        """
        # A list literal in the signature would be a single object shared by
        # every call, so the default is resolved here instead.
        if output_dim is None:
            output_dim = [1000, 1000, 10]

        self.name = name
        self.is_train = is_train
        self.log_dir = log_dir
        self.cur_lr = None
        self.logger, self.file_handler, self.stream_handler = utils.init_logger(log_dir=self.log_dir,
                                                                                name=self.name,
                                                                                is_train=self.is_train)
        # Layer helpers are only given the logger while training.
        layer_logger = self.logger if self.is_train else None

        with tf.variable_scope(self.name):
            # Placeholders for inputs
            self.X = tf.placeholder(dtype=tf.float32, shape=[None, input_dim], name='X')
            tf_utils.print_activations(self.X, logger=layer_logger)
            self.y = tf.placeholder(dtype=tf.float32, shape=[None, output_dim[-1]], name='y')
            self.y_cls = tf.math.argmax(input=self.y, axis=1)
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

            # Placeholders for TensorBoard
            self.train_acc = tf.placeholder(tf.float32, name='train_acc')
            self.val_acc = tf.placeholder(tf.float32, name='val_acc')

            # Hidden stack: linear -> (dropout) -> relu for every width but the last
            net = self.X
            for idx, hidden_units in enumerate(output_dim[:-1]):
                net = tf_utils.linear(x=net,
                                      output_size=hidden_units,
                                      name='fc'+str(idx),
                                      logger=layer_logger)

                if use_dropout:
                    net = tf_utils.dropout(x=net,
                                           keep_prob=self.keep_prob,
                                           seed=random_seed,
                                           name='dropout'+str(idx),
                                           logger=layer_logger)

                net = tf_utils.relu(x=net,
                                    name='relu'+str(idx),
                                    logger=layer_logger)

            # Last predict layer
            self.y_pred = tf_utils.linear(net, output_size=output_dim[-1], name='last_fc')
            tf_utils.print_activations(self.y_pred, logger=layer_logger)

            # Loss = data loss + regularization term
            self.data_loss = tf.math.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=self.y_pred, labels=self.y))
            # NOTE(review): the collection is read without a scope filter, so
            # the sum presumably covers every trainable variable registered in
            # the graph so far - confirm this is intended.
            self.reg_term = weight_decay * tf.reduce_sum(
                [tf.nn.l2_loss(weight) for weight in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)])
            self.loss = self.data_loss + self.reg_term

            # Optimizer
            self.train_op, self.cur_lr = optimizer_fn(optimizer, lr=lr, loss=self.loss, name=self.name)

            # Accuracy etc: percentage of rows whose arg-max prediction matches the arg-max label
            self.y_pred_cls = tf.math.argmax(input=self.y_pred, axis=1)
            correct_prediction = tf.math.equal(self.y_pred_cls, self.y_cls)
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32)) * 100.

        self._tensorboard()
        tf_utils.show_all_variables(logger=layer_logger)
示例#17
0
    def __call__(self, x):
        """Build the encoder-decoder generator graph on ``x``.

        Eight convolution stages are followed by eight deconvolution stages.
        Decoder stages 0-6 are concatenated on the channel axis with the
        matching encoder output; where the deconvolution output is larger than
        that encoder output it is first cut down with ``tf.split``.
        Decoder stages 0-2 apply dropout with a fixed keep probability of 0.5.

        The spatial sizes quoted in the comments are the ones annotated in the
        original code - assumes a (300, 200) input; TODO confirm.

        Args:
            x: input tensor handed to the first convolution.

        Returns:
            The tanh of the last deconvolution output.

        Side effects:
            Sets ``self.reuse`` to True and stores the trainable variables
            collected under ``self.name`` in ``self.variables``.
        """
        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            def crop_to(tensor, reference, axis, name):
                # Keep the leading slice of `tensor` whose length along `axis`
                # equals that of `reference`; the remainder is discarded.
                keep = reference.get_shape().as_list()[axis]
                surplus = tensor.get_shape().as_list()[axis] - keep
                kept, _ = tf.split(tensor, [keep, surplus], axis=axis, name=name)
                tf_utils.print_activations(kept)
                return kept

            # ---------------- encoder ----------------
            # (300, 200) -> (150, 100); the first stage has no batch norm
            e0_conv2d = tf_utils.conv2d(x, self.gen_c[0], name='e0_conv2d')
            net = tf_utils.lrelu(e0_conv2d, name='e0_lrelu')

            # Stages 1..7 halve the spatial size each time:
            # (150, 100) -> (75, 50) -> (38, 25) -> (19, 13) -> (10, 7)
            # -> (5, 4) -> (3, 2) -> (2, 1)
            convs, norms = {}, {}
            for stage in range(1, 8):
                tag = 'e{}'.format(stage)
                convs[stage] = tf_utils.conv2d(net,
                                               self.gen_c[stage],
                                               name=tag + '_conv2d')
                norms[stage] = tf_utils.batch_norm(convs[stage],
                                                   name=tag + '_batchnorm',
                                                   _ops=self._ops)
                if stage < 7:
                    net = tf_utils.lrelu(norms[stage], name=tag + '_lrelu')
                else:
                    # the innermost stage uses a plain ReLU
                    net = tf_utils.relu(norms[stage], name=tag + '_relu')

            # ---------------- decoder ----------------
            # One row per stage:
            # (stage, skip tensor, crop reference, ((axis, split name), ...), dropout?)
            decoder_plan = (
                # (2, 1) -> (4, 2) -> (3, 2)
                (0, norms[6], convs[6], ((1, 'd0_split'),), True),
                # (3, 2) -> (6, 4) -> (5, 4)
                (1, norms[5], norms[5], ((1, 'd1_split'),), True),
                # (5, 4) -> (10, 8) -> (10, 7)
                (2, norms[4], norms[4], ((2, 'd2_split'),), True),
                # (10, 7) -> (20, 14) -> (19, 14) -> (19, 13)
                (3, norms[3], norms[3], ((1, 'd3_split_1'), (2, 'd3_split_2')), False),
                # (19, 13) -> (38, 26) -> (38, 25)
                (4, norms[2], norms[2], ((2, 'd4_split'),), False),
                # (38, 25) -> (76, 50) -> (75, 50)
                (5, norms[1], norms[1], ((1, 'd5_split'),), False),
                # (75, 50) -> (150, 100); sizes already match, nothing to crop
                (6, e0_conv2d, None, (), False),
            )
            for stage, skip, crop_ref, crops, with_dropout in decoder_plan:
                tag = 'd{}'.format(stage)
                net = tf_utils.deconv2d(net,
                                        self.gen_c[8 + stage],
                                        name=tag + '_deconv2d')
                for axis, split_name in crops:
                    net = crop_to(net, crop_ref, axis, split_name)
                net = tf_utils.batch_norm(net,
                                          name=tag + '_batchnorm',
                                          _ops=self._ops)
                if with_dropout:
                    net = tf.nn.dropout(net,
                                        keep_prob=0.5,
                                        name=tag + '_dropout')
                net = tf.concat([net, skip], axis=3, name=tag + '_concat')
                net = tf_utils.relu(net, name=tag + '_relu')

            # (150, 100) -> (300, 200)
            d7_deconv = tf_utils.deconv2d(net,
                                          self.gen_c[15],
                                          name='d7_deconv2d')
            output = tf_utils.tanh(d7_deconv, name='output_tanh')

            # later calls must reuse the variables created above
            self.reuse = True
            self.variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

            return output
示例#18
0
    def __call__(self, x):
        """Build the convolution / deconvolution generator graph on ``x``.

        The layout is one 4x4 stride-2 convolution with leaky ReLU, then one
        convolution + norm + leaky ReLU block per remaining entry of
        ``self.conv_dims``, then one deconvolution + norm + ReLU block per
        entry of ``self.deconv_dims``, and finally a deconvolution to
        ``self.output_channel`` channels followed by tanh.

        Args:
            x: input tensor handed to the first convolution.

        Returns:
            The tanh output tensor.

        Side effects:
            Sets ``self.reuse`` to True and stores the trainable variables
            collected under ``self.name`` in ``self.variables``.
        """
        # Settings shared by every strided convolution in the down-sampling path.
        conv_opts = dict(k_h=4, k_w=4, d_h=2, d_w=2, padding='SAME')

        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            # conv: (N, H, W, C) -> (N, H/2, W/2, 64); no normalisation here
            net = tf_utils.conv2d(x,
                                  self.conv_dims[0],
                                  name='conv0_conv2d',
                                  **conv_opts)
            net = tf_utils.lrelu(net, name='conv0_lrelu', is_print=True)

            # Remaining down-sampling blocks, numbered from 1.
            for layer_no, channels in enumerate(self.conv_dims[1:], start=1):
                tag = 'conv{}'.format(layer_no)
                net = tf_utils.conv2d(net,
                                      channels,
                                      name=tag + '_conv2d',
                                      **conv_opts)
                net = tf_utils.norm(net,
                                    _type=self.norm,
                                    _ops=self._ops,
                                    name=tag + '_norm')
                net = tf_utils.lrelu(net, name=tag + '_lrelu', is_print=True)

            # Up-sampling blocks, numbered from 0.
            for layer_no, channels in enumerate(self.deconv_dims):
                tag = 'deconv{}'.format(layer_no)
                net = tf_utils.deconv2d(net,
                                        channels,
                                        k_h=4,
                                        k_w=4,
                                        name=tag + '_conv2d')
                net = tf_utils.norm(net,
                                    _type=self.norm,
                                    _ops=self._ops,
                                    name=tag + '_norm')
                net = tf_utils.relu(net, name=tag + '_relu', is_print=True)

            # Final deconvolution to the output channel count, squashed by tanh.
            net = tf_utils.deconv2d(net,
                                    self.output_channel,
                                    k_h=4,
                                    k_w=4,
                                    name='conv3_deconv2d')
            net = tf_utils.tanh(net, name='conv4_tanh', is_print=True)

            # later calls must reuse the variables created above
            self.reuse = True
            self.variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
            return net
示例#19
0
    def __call__(self, x, is_train=True):
        """Build the five-convolution network on ``x`` and return its linear output.

        The first block is convolution + leaky ReLU; blocks 1-4 add a batch
        norm between the two. The last activation is flattened and fed to a
        linear layer of width ``self.dims[5]``.

        The tensor sizes quoted in the comments are the ones annotated in the
        original code - assumes an (N, 120, 160, 1) input; TODO confirm.

        Args:
            x: input tensor handed to the first convolution.
            is_train: forwarded to every norm layer. The layer helpers receive
                ``self.logger`` only when this is exactly ``True``.

        Returns:
            The output tensor of the final linear layer.

        Side effects:
            Sets ``self.reuse`` to True and stores the trainable variables
            collected under ``self.name`` in ``self.variables``.
        """
        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            # Identity check on purpose: only the literal True enables logging.
            logger = self.logger if is_train is True else None

            # (N, 120, 160, 1) -> (N, 60, 80, 64); first block has no norm
            net = tf_utils.conv2d(net_input := x,
                                  output_dim=self.dims[0],
                                  initializer='he',
                                  name='h0_conv',
                                  logger=logger) if False else tf_utils.conv2d(
                                      x,
                                      output_dim=self.dims[0],
                                      initializer='he',
                                      name='h0_conv',
                                      logger=logger)
            net = tf_utils.lrelu(net, name='h0_lrelu', logger=logger)

            # (N, 60, 80, 64) -> (N, 30, 40, 128) -> (N, 15, 20, 256)
            # -> (N, 8, 10, 512) -> (N, 4, 5, 1024)
            for layer in range(1, 5):
                tag = 'h{}'.format(layer)
                net = tf_utils.conv2d(net,
                                      output_dim=self.dims[layer],
                                      initializer='he',
                                      name=tag + '_conv',
                                      logger=logger)
                net = tf_utils.norm(net,
                                    name=tag + '_batch',
                                    _type='batch',
                                    _ops=self._ops,
                                    is_train=is_train,
                                    logger=logger)
                net = tf_utils.lrelu(net, name=tag + '_lrelu', logger=logger)

            # (N, 4, 5, 1024) -> (N, 4*5*1024)
            flat = tf_utils.flatten(net, name='h4_flatten', logger=logger)

            # (N, 4*5*1024) -> (N, 1)
            output = tf_utils.linear(flat,
                                     output_size=self.dims[5],
                                     initializer='he',
                                     name='output',
                                     logger=logger)

            # later calls must reuse the variables created above
            self.reuse = True
            self.variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

            return output
示例#20
0
    def generator(self, data, name='g_'):
        """Build the residual up-sampling generator graph on ``data``.

        The flattened input is projected by a linear layer, reshaped to a
        2 x 4 feature map, passed through six ``res_block_v2`` blocks with
        ``resample='up'``, then batch norm, ReLU, a 3x3 stride-1 convolution
        and tanh.

        The tensor sizes quoted in the comments are the ones annotated in the
        original code - assumes an (N, 128) input; TODO confirm.

        Args:
            data: input tensor; it is flattened before the first linear layer.
            name: variable-scope name for the generator.

        Returns:
            The tanh of the final convolution output.
        """
        with tf.variable_scope(name):
            flat = flatten(data)
            tf_utils.print_activations(flat)

            # from (N, 128) to (N, 2, 4, 512)
            projected = tf_utils.linear(flat,
                                        self.gen_c[0],
                                        name='h0_linear')
            batch = tf.shape(projected)[0]
            # the linear width is spread over the 2 x 4 spatial grid
            channels = int(self.gen_c[0] / (2 * 4))
            net = tf.reshape(projected, [batch, 2, 4, channels])

            # Six up-sampling residual blocks:
            # (N, 4, 8, 512) -> (N, 8, 16, 256) -> (N, 16, 32, 128)
            # -> (N, 32, 64, 64) -> (N, 64, 128, 32) -> (N, 128, 256, 32)
            for block_no in range(1, 7):
                net = tf_utils.res_block_v2(
                    net,
                    self.gen_c[block_no],
                    filter_size=3,
                    _ops=self.gen_train_ops,
                    norm_='batch',
                    resample='up',
                    name='res_block_{}'.format(block_no))

            net = tf_utils.norm(net,
                                _type='batch',
                                _ops=self.gen_train_ops,
                                name='norm_7')
            net = tf_utils.relu(net, name='relu_7')

            # (N, 128, 256, 3)
            net = tf_utils.conv2d(net,
                                  output_dim=self.image_size[2],
                                  k_w=3,
                                  k_h=3,
                                  d_h=1,
                                  d_w=1,
                                  name='output')

            return tf_utils.tanh(net)
    def forward_network(self, inputImg, padding='SAME', reuse=False):
        with tf.compat.v1.variable_scope(self.name, reuse=reuse):
            # This part is for compatible between input size [640, 400] and [320, 200]
            if self.resize_factor == 1.0:
                # Stage 0
                tf_utils.print_activations(inputImg, logger=self.logger)
                s0_conv1 = tf_utils.conv2d(x=inputImg, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s0_conv1', logger=self.logger)
                s0_conv1 = tf_utils.relu(s0_conv1, name='relu_s0_conv1', logger=self.logger)

                s0_conv2 = tf_utils.conv2d(x=s0_conv1, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s0_conv2', logger=self.logger)
                if self.use_batch_norm:
                    s0_conv2 = tf_utils.norm(s0_conv2, name='s0_norm1', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s0_conv2 = tf_utils.relu(s0_conv2, name='relu_s0_conv2', logger=self.logger)

                # Stage 1
                s1_maxpool = tf_utils.max_pool(x=s0_conv2, name='s1_maxpool2d', logger=self.logger)

                s1_conv1 = tf_utils.conv2d(x=s1_maxpool, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s1_conv1', logger=self.logger)
                if self.use_batch_norm:
                    s1_conv1 = tf_utils.norm(s1_conv1, name='s1_norm0', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1', logger=self.logger)

                s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s1_conv2', logger=self.logger)
                if self.use_batch_norm:
                    s1_conv2 = tf_utils.norm(s1_conv2, name='s1_norm1', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2', logger=self.logger)
            else:
                # Stage 1
                tf_utils.print_activations(inputImg, logger=self.logger)
                s1_conv1 = tf_utils.conv2d(x=inputImg, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s1_conv1', logger=self.logger)
                s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1', logger=self.logger)

                s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1,
                                           padding=padding, initializer='He', name='s1_conv2', logger=self.logger)
                if self.use_batch_norm:
                    s1_conv2 = tf_utils.norm(s1_conv2, name='s1_norm1', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2', logger=self.logger)

            # Stage 2
            s2_maxpool = tf_utils.max_pool(x=s1_conv2, name='s2_maxpool2d', logger=self.logger)
            s2_conv1 = tf_utils.conv2d(x=s2_maxpool, output_dim=self.conv_dims[2], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s2_conv1', logger=self.logger)
            if self.use_batch_norm:
                s2_conv1 = tf_utils.norm(s2_conv1, name='s2_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s2_conv1 = tf_utils.relu(s2_conv1, name='relu_s2_conv1', logger=self.logger)

            s2_conv2 = tf_utils.conv2d(x=s2_conv1, output_dim=self.conv_dims[3], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s2_conv2', logger=self.logger)
            if self.use_batch_norm:
                s2_conv2 = tf_utils.norm(s2_conv2, name='s2_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s2_conv2 = tf_utils.relu(s2_conv2, name='relu_s2_conv2', logger=self.logger)

            # Stage 3
            s3_maxpool = tf_utils.max_pool(x=s2_conv2, name='s3_maxpool2d', logger=self.logger)
            s3_conv1 = tf_utils.conv2d(x=s3_maxpool, output_dim=self.conv_dims[4], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s3_conv1', logger=self.logger)
            if self.use_batch_norm:
                s3_conv1 = tf_utils.norm(s3_conv1, name='s3_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s3_conv1 = tf_utils.relu(s3_conv1, name='relu_s3_conv1', logger=self.logger)

            s3_conv2 = tf_utils.conv2d(x=s3_conv1, output_dim=self.conv_dims[5], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s3_conv2', logger=self.logger)
            if self.use_batch_norm:
                s3_conv2 = tf_utils.norm(s3_conv2, name='s3_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s3_conv2 = tf_utils.relu(s3_conv2, name='relu_s3_conv2', logger=self.logger)

            # Stage 4
            s4_maxpool = tf_utils.max_pool(x=s3_conv2, name='s4_maxpool2d', logger=self.logger)
            s4_conv1 = tf_utils.conv2d(x=s4_maxpool, output_dim=self.conv_dims[6], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s4_conv1', logger=self.logger)
            if self.use_batch_norm:
                s4_conv1 = tf_utils.norm(s4_conv1, name='s4_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s4_conv1 = tf_utils.relu(s4_conv1, name='relu_s4_conv1', logger=self.logger)

            s4_conv2 = tf_utils.conv2d(x=s4_conv1, output_dim=self.conv_dims[7], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s4_conv2', logger=self.logger)
            if self.use_batch_norm:
                s4_conv2 = tf_utils.norm(s4_conv2, name='s4_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s4_conv2 = tf_utils.relu(s4_conv2, name='relu_s4_conv2', logger=self.logger)
            s4_conv2_drop = tf_utils.dropout(x=s4_conv2, keep_prob=self.ratePh, name='s4_dropout',
                                             logger=self.logger)

            # Stage 5
            s5_maxpool = tf_utils.max_pool(x=s4_conv2_drop, name='s5_maxpool2d', logger=self.logger)
            s5_conv1 = tf_utils.conv2d(x=s5_maxpool, output_dim=self.conv_dims[8], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s5_conv1', logger=self.logger)
            if self.use_batch_norm:
                s5_conv1 = tf_utils.norm(s5_conv1, name='s5_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s5_conv1 = tf_utils.relu(s5_conv1, name='relu_s5_conv1', logger=self.logger)

            s5_conv2 = tf_utils.conv2d(x=s5_conv1, output_dim=self.conv_dims[9], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s5_conv2', logger=self.logger)
            if self.use_batch_norm:
                s5_conv2 = tf_utils.norm(s5_conv2, name='s5_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s5_conv2 = tf_utils.relu(s5_conv2, name='relu_s5_conv2', logger=self.logger)
            s5_conv2_drop = tf_utils.dropout(x=s5_conv2, keep_prob=self.ratePh, name='s5_dropout',
                                             logger=self.logger)

            # Stage 6
            s6_deconv1 = tf_utils.deconv2d(x=s5_conv2_drop, output_dim=self.conv_dims[10], k_h=2, k_w=2,
                                           initializer='He', name='s6_deconv1', logger=self.logger)
            if self.use_batch_norm:
                s6_deconv1 = tf_utils.norm(s6_deconv1, name='s6_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s6_deconv1 = tf_utils.relu(s6_deconv1, name='relu_s6_deconv1', logger=self.logger)
            # Cropping
            w1 = s4_conv2_drop.get_shape().as_list()[2]
            w2 = s6_deconv1.get_shape().as_list()[2] - s4_conv2_drop.get_shape().as_list()[2]
            s6_deconv1_split, _ = tf.split(s6_deconv1, num_or_size_splits=[w1, w2], axis=2, name='axis2_split')
            tf_utils.print_activations(s6_deconv1_split, logger=self.logger)
            # Concat
            s6_concat = tf_utils.concat(values=[s6_deconv1_split, s4_conv2_drop], axis=3, name='s6_axis3_concat',
                                        logger=self.logger)

            s6_conv2 = tf_utils.conv2d(x=s6_concat, output_dim=self.conv_dims[11], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s6_conv2', logger=self.logger)
            if self.use_batch_norm:
                s6_conv2 = tf_utils.norm(s6_conv2, name='s6_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s6_conv2 = tf_utils.relu(s6_conv2, name='relu_s6_conv2', logger=self.logger)

            s6_conv3 = tf_utils.conv2d(x=s6_conv2, output_dim=self.conv_dims[12], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s6_conv3', logger=self.logger)
            if self.use_batch_norm:
                s6_conv3 = tf_utils.norm(s6_conv3, name='s6_norm2', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s6_conv3 = tf_utils.relu(s6_conv3, name='relu_s6_conv3', logger=self.logger)

            # Stage 7
            s7_deconv1 = tf_utils.deconv2d(x=s6_conv3, output_dim=self.conv_dims[13], k_h=2, k_w=2, initializer='He',
                                           name='s7_deconv1', logger=self.logger)
            if self.use_batch_norm:
                s7_deconv1 = tf_utils.norm(s7_deconv1, name='s7_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s7_deconv1 = tf_utils.relu(s7_deconv1, name='relu_s7_deconv1', logger=self.logger)
            # Concat
            s7_concat = tf_utils.concat(values=[s7_deconv1, s3_conv2], axis=3, name='s7_axis3_concat',
                                        logger=self.logger)

            s7_conv2 = tf_utils.conv2d(x=s7_concat, output_dim=self.conv_dims[14], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s7_conv2', logger=self.logger)
            if self.use_batch_norm:
                s7_conv2 = tf_utils.norm(s7_conv2, name='s7_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s7_conv2 = tf_utils.relu(s7_conv2, name='relu_s7_conv2', logger=self.logger)

            s7_conv3 = tf_utils.conv2d(x=s7_conv2, output_dim=self.conv_dims[15], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s7_conv3', logger=self.logger)
            if self.use_batch_norm:
                s7_conv3 = tf_utils.norm(s7_conv3, name='s7_norm2', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s7_conv3 = tf_utils.relu(s7_conv3, name='relu_s7_conv3', logger=self.logger)

            # Stage 8
            s8_deconv1 = tf_utils.deconv2d(x=s7_conv3, output_dim=self.conv_dims[16], k_h=2, k_w=2, initializer='He',
                                           name='s8_deconv1', logger=self.logger)
            if self.use_batch_norm:
                s8_deconv1 = tf_utils.norm(s8_deconv1, name='s8_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s8_deconv1 = tf_utils.relu(s8_deconv1, name='relu_s8_deconv1', logger=self.logger)
            # Concat
            s8_concat = tf_utils.concat(values=[s8_deconv1,s2_conv2], axis=3, name='s8_axis3_concat',
                                        logger=self.logger)

            s8_conv2 = tf_utils.conv2d(x=s8_concat, output_dim=self.conv_dims[17], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s8_conv2', logger=self.logger)
            if self.use_batch_norm:
                s8_conv2 = tf_utils.norm(s8_conv2, name='s8_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s8_conv2 = tf_utils.relu(s8_conv2, name='relu_s8_conv2', logger=self.logger)

            s8_conv3 = tf_utils.conv2d(x=s8_conv2, output_dim=self.conv_dims[18], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s8_conv3', logger=self.logger)
            if self.use_batch_norm:
                s8_conv3 = tf_utils.norm(s8_conv3, name='s8_norm2', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s8_conv3 = tf_utils.relu(s8_conv3, name='relu_conv3', logger=self.logger)

            # Stage 9
            s9_deconv1 = tf_utils.deconv2d(x=s8_conv3, output_dim=self.conv_dims[19], k_h=2, k_w=2,
                                           initializer='He', name='s9_deconv1', logger=self.logger)
            if self.use_batch_norm:
                s9_deconv1 = tf_utils.norm(s9_deconv1, name='s9_norm0', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s9_deconv1 = tf_utils.relu(s9_deconv1, name='relu_s9_deconv1', logger=self.logger)
            # Concat
            s9_concat = tf_utils.concat(values=[s9_deconv1, s1_conv2], axis=3, name='s9_axis3_concat',
                                        logger=self.logger)

            s9_conv2 = tf_utils.conv2d(x=s9_concat, output_dim=self.conv_dims[20], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s9_conv2', logger=self.logger)
            if self.use_batch_norm:
                s9_conv2 = tf_utils.norm(s9_conv2, name='s9_norm1', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s9_conv2 = tf_utils.relu(s9_conv2, name='relu_s9_conv2', logger=self.logger)

            s9_conv3 = tf_utils.conv2d(x=s9_conv2, output_dim=self.conv_dims[21], k_h=3, k_w=3, d_h=1, d_w=1,
                                       padding=padding, initializer='He', name='s9_conv3', logger=self.logger)
            if self.use_batch_norm:
                s9_conv3 = tf_utils.norm(s9_conv3, name='s9_norm2', _type='batch', _ops=self._ops,
                                         is_train=self.trainMode, logger=self.logger)
            s9_conv3 = tf_utils.relu(s9_conv3, name='relu_s9_conv3', logger=self.logger)

            if self.resize_factor == 1.0:
                s10_deconv1 = tf_utils.deconv2d(x=s9_conv3, output_dim=self.conv_dims[-1], k_h=2, k_w=2,
                                                initializer='He', name='s10_deconv1', logger=self.logger)
                if self.use_batch_norm:
                    s10_deconv1 = tf_utils.norm(s10_deconv1, name='s10_norm0', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s10_deconv1 = tf_utils.relu(s10_deconv1, name='relu_s10_deconv1', logger=self.logger)
                # Concat
                s10_concat = tf_utils.concat(values=[s10_deconv1, s0_conv2], axis=3, name='s10_axis3_concat',
                                             logger=self.logger)

                s10_conv2 = tf_utils.conv2d(s10_concat, output_dim=self.conv_dims[-1], k_h=3, k_w=3, d_h=1, d_w=1,
                                            padding=padding, initializer='He', name='s10_conv2', logger=self.logger)
                if self.use_batch_norm:
                    s10_conv2 = tf_utils.norm(s10_conv2, name='s10_norm1', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s10_conv2 = tf_utils.relu(s10_conv2, name='relu_s10_conv2', logger=self.logger)

                s10_conv3 = tf_utils.conv2d(x=s10_conv2, output_dim=self.conv_dims[-1], k_h=3, k_w=3, d_h=1, d_w=1,
                                            padding=padding, initializer='He', name='s10_conv3', logger=self.logger)
                if self.use_batch_norm:
                    s10_conv3 = tf_utils.norm(s10_conv3, name='s10_norm2', _type='batch', _ops=self._ops,
                                             is_train=self.trainMode, logger=self.logger)
                s10_conv3 = tf_utils.relu(s10_conv3, name='relu_s10_conv3', logger=self.logger)

                output = tf_utils.conv2d(s10_conv3, output_dim=self.numClasses, k_h=1, k_w=1, d_h=1, d_w=1,
                                         padding=padding, initializer='He', name='output', logger=self.logger)
            else:
                output = tf_utils.conv2d(s9_conv3, output_dim=self.numClasses, k_h=1, k_w=1, d_h=1, d_w=1,
                                         padding=padding, initializer='He', name='output', logger=self.logger)

            return output
    def __call__(self, x, keep_rate=0.5):
        """Build the generator graph: eight encoder stages, then eight decoder stages.

        Every decoder stage except the last concatenates its result with an
        encoder tensor along axis 3 (skip connection).  Decoder stages D0-D4
        first crop the deconvolution output with ``tf.split`` so that the
        cropped axis matches the corresponding encoder tensor.

        Args:
            x: Input tensor fed to the first convolution.
                NOTE(review): the per-stage size comments below are the
                original author's annotations for a (320, 200) input -- confirm.
            keep_rate: Forwarded as ``keep_prob`` to the dropout used in
                decoder stages D0, D1 and D2.

        Returns:
            The result of ``tf_utils.tanh`` applied to the final deconvolution.

        Side effects:
            Sets ``self.reuse`` to ``True`` and stores in ``self.variables``
            the trainable-variables collection filtered by scope ``self.name``.
        """
        with tf.compat.v1.variable_scope(self.name, reuse=self.reuse):
            log = self.logger
            tf_utils.print_activations(x, logger=log)

            # ------------------------------------------------------------
            # Encoder
            # ------------------------------------------------------------
            # E0: (320, 200) -> (160, 100); conv + leaky ReLU, no normalisation.
            e0_conv = tf_utils.conv2d(x,
                                      output_dim=self.gen_c[0],
                                      initializer='He',
                                      logger=log,
                                      name='e0_conv2d')
            feat = tf_utils.lrelu(e0_conv, logger=log, name='e0_lrelu')

            # E1..E7: conv -> norm -> leaky ReLU.  Annotated sizes:
            # (160, 100) -> (80, 50) -> (40, 25) -> (20, 13) -> (10, 7)
            # -> (5, 4) -> (3, 2) -> (2, 1)
            enc_convs = [e0_conv]   # raw convolution output of each encoder stage
            enc_norms = [None]      # normalised output of each stage (E0 has none)
            for level in range(1, 8):
                conv = tf_utils.conv2d(feat,
                                       output_dim=self.gen_c[level],
                                       initializer='He',
                                       logger=log,
                                       name='e{}_conv2d'.format(level))
                normed = tf_utils.norm(conv,
                                       _type=self.norm,
                                       _ops=self._ops,
                                       logger=log,
                                       name='e{}_norm'.format(level))
                # The last encoder activation keeps its historical op name
                # 'e7_relu' even though it is built with lrelu.
                act_name = 'e7_relu' if level == 7 else 'e{}_lrelu'.format(level)
                feat = tf_utils.lrelu(normed, logger=log, name=act_name)
                enc_convs.append(conv)
                enc_norms.append(normed)

            # ------------------------------------------------------------
            # Decoder
            # ------------------------------------------------------------
            # One row per decoder stage D0..D6:
            #   (crop axis or None, split op name, tensor whose static size
            #    defines the crop, skip tensor to concatenate, use dropout?)
            # Annotated sizes: D0 (2, 1)->(4, 2)->(3, 2); D1 (3, 2)->(6, 4)->(5, 4);
            # D2 (5, 4)->(10, 8)->(10, 7); D3 (10, 7)->(20, 14)->(20, 13);
            # D4 (20, 13)->(40, 26)->(40, 25); D5 (40, 25)->(80, 50);
            # D6 (80, 50)->(160, 100).
            decoder_plan = (
                (1, 'd0_split', enc_convs[6], enc_norms[6], True),
                (1, 'd1_split', enc_norms[5], enc_norms[5], True),
                (2, 'd2_split', enc_norms[4], enc_norms[4], True),
                (2, 'd3_split_2', enc_norms[3], enc_norms[3], False),
                (2, 'd4_split', enc_norms[2], enc_norms[2], False),
                (None, None, None, enc_norms[1], False),
                (None, None, None, e0_conv, False),
            )
            for stage, plan in enumerate(decoder_plan):
                crop_axis, split_name, crop_ref, skip, with_dropout = plan
                prefix = 'd{}'.format(stage)

                # Stage 1: deconvolution.
                up = tf_utils.deconv2d(feat,
                                       output_dim=self.gen_c[8 + stage],
                                       initializer='He',
                                       logger=log,
                                       name=prefix + '_deconv2d')

                # Stage 2: drop the surplus rows/columns so the cropped axis
                # matches the encoder tensor (static shapes are used).
                if crop_axis is not None:
                    keep = crop_ref.get_shape().as_list()[crop_axis]
                    surplus = up.get_shape().as_list()[crop_axis] - keep
                    up, _ = tf.split(up, [keep, surplus],
                                     axis=crop_axis,
                                     name=split_name)
                    tf_utils.print_activations(up, logger=log)

                # Stage 3: norm, optional dropout, skip concatenation, ReLU.
                up = tf_utils.norm(up,
                                   _type=self.norm,
                                   _ops=self._ops,
                                   logger=log,
                                   name=prefix + '_norm')
                if with_dropout:
                    up = tf_utils.dropout(up,
                                          keep_prob=keep_rate,
                                          logger=log,
                                          name=prefix + '_dropout')
                merged = tf.concat([up, skip], axis=3, name=prefix + '_concat')
                feat = tf_utils.relu(merged, logger=log, name=prefix + '_relu')

            # D7: (160, 100, 64) -> (320, 200, 1); deconvolution + tanh, no skip.
            d7_deconv = tf_utils.deconv2d(feat,
                                          output_dim=self.gen_c[15],
                                          initializer='He',
                                          logger=log,
                                          name='d7_deconv2d')
            output = tf_utils.tanh(d7_deconv, logger=log, name='output_tanh')

            # Later calls must share the variables created above.
            self.reuse = True
            self.variables = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

        return output
示例#23
0
    def encoder(self, data, name='encoder'):
        """Map ``data`` to the mean and standard deviation of a latent Gaussian.

        The input is flattened, passed through two hidden layers
        (linear -> ELU -> dropout, then linear -> tanh -> dropout), and
        projected to ``2 * self.flags.z_dim`` units which are split in half.

        Args:
            data: Tensor handed to ``flatten``.
            name: Variable scope under which the layers are created.

        Returns:
            Tuple ``(mean, stddev)``: ``mean`` is the first ``z_dim`` columns
            of the projection; ``stddev`` is ``1e-6 + softplus`` of the
            remaining columns.
        """
        with tf.variable_scope(name):
            hidden = flatten(data)
            tf_utils.print_activations(hidden)

            # (layer prefix, activation function, activation tag)
            hidden_specs = (
                ('h0', tf_utils.elu, 'elu'),    # 1st hidden layer
                ('h1', tf_utils.tanh, 'tanh'),  # 2nd hidden layer
            )
            for prefix, activation, tag in hidden_specs:
                hidden = tf_utils.linear(hidden,
                                         self.n_hidden,
                                         name=prefix + '_linear')
                hidden = activation(hidden, name=prefix + '_' + tag)
                hidden = tf.nn.dropout(hidden,
                                       keep_prob=self.keep_prob_tfph,
                                       name=prefix + '_drop')
                tf_utils.print_activations(hidden)

            # 3rd layer: one linear projection holding both Gaussian parameters.
            z_dim = self.flags.z_dim
            gaussian_params = tf_utils.linear(hidden,
                                              2 * z_dim,
                                              name='h2_linear')
            tf_utils.print_activations(gaussian_params)

            # First half: the mean, left unconstrained.
            mean = gaussian_params[:, :z_dim]
            # Second half: the standard deviation, which must be positive, so
            # it goes through a softplus plus a small epsilon for stability.
            raw_std = gaussian_params[:, z_dim:]
            stddev = 1e-6 + tf.nn.softplus(raw_std)

            tf_utils.print_activations(mean)
            tf_utils.print_activations(stddev)

        return mean, stddev
示例#24
0
    def __call__(self, x, is_train=True):
        """Build the residual-block generator graph and return its output.

        The input goes through a linear layer and is reshaped to
        ``(batch, 4, 5, self.dims[0])``.  Five ``res_block_v2`` blocks created
        with ``resample='up'`` follow; after the second one the tensor is
        split ``[15, 1]`` along axis 1 and only the first part is kept.  A
        batch norm, a ReLU, a 3x3 stride-1 convolution and a ``tanh`` close
        the network.

        Args:
            x: Input tensor fed to the first linear layer.
            is_train: Forwarded to the final batch-norm layer.  Per-layer
                logging through ``self.logger`` is enabled only when this is
                exactly ``True``; otherwise ``None`` is passed as the logger.

        Returns:
            The tensor produced by the final ``tanh`` layer.

        Side effects:
            Sets ``self.reuse`` to ``True`` for the next call and stores the
            trainable variables collected under ``self.name`` in
            ``self.variables``.
        """
        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            # One logger choice shared by every layer below.
            log = self.logger if is_train is True else None

            def upsample(inp, idx, block_name):
                # Residual block whose output width is self.dims[idx].
                return tf_utils.res_block_v2(x=inp,
                                             k=self.dims[idx],
                                             filter_size=3,
                                             _ops=self._ops,
                                             norm_='batch',
                                             resample='up',
                                             name=block_name,
                                             logger=log)

            # NOTE: the shape annotations below are the original author's and
            # presumably assume a specific self.dims configuration.

            # (N, 100) -> (N, 4, 5, 512)
            projected = tf_utils.linear(x,
                                        4 * 5 * self.dims[0],
                                        name='h0_linear',
                                        initializer='He',
                                        logger=log)
            batch_size = tf.shape(projected)[0]
            feat = tf.reshape(projected, [batch_size, 4, 5, self.dims[0]])

            # (N, 4, 5, 512) -> (N, 8, 10, 512)
            feat = upsample(feat, 1, 'res_block_1')

            # (N, 8, 10, 512) -> (N, 16, 20, 256)
            feat = upsample(feat, 2, 'res_block_2')

            # (N, 16, 20, 256) -> (N, 15, 20, 256): drop the last row on axis 1
            kept_rows, _ = tf.split(feat, [15, 1],
                                    axis=1,
                                    name='resblock_2_split')
            tf_utils.print_activations(kept_rows, logger=log)

            # (N, 15, 20, 256) -> (N, 30, 40, 128)
            feat = upsample(kept_rows, 3, 'res_block_3')

            # (N, 30, 40, 128) -> (N, 60, 80, 64)
            feat = upsample(feat, 4, 'res_block_4')

            # (N, 60, 80, 64) -> (N, 120, 160, 64)
            feat = upsample(feat, 5, 'res_block_5')

            normed = tf_utils.norm(feat,
                                   name='norm_5',
                                   _type='batch',
                                   _ops=self._ops,
                                   is_train=is_train,
                                   logger=log)
            activated = tf_utils.relu(normed, name='relu_5', logger=log)

            # (N, 120, 160, 64) -> (N, 120, 160, 3)
            projected_img = tf_utils.conv2d(activated,
                                            output_dim=self.dims[6],
                                            k_h=3,
                                            k_w=3,
                                            d_h=1,
                                            d_w=1,
                                            name='conv_6',
                                            logger=log)

            output = tf_utils.tanh(projected_img, name='output', logger=log)

        # Subsequent calls must reuse the variables created above.
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           scope=self.name)

        return output
示例#25
0
    def __call__(self, x, is_train=True):
        """Build the transposed-convolution generator graph and return its output.

        The input goes through a linear layer, is reshaped to
        ``(batch, 4, 5, self.dims[0])`` and then batch-normalised and passed
        through a ReLU.  Four ``deconv2d`` + batch-norm + ReLU stages follow;
        after the second one the tensor is split ``[15, 1]`` along axis 1 and
        only the first part is kept.  A last ``deconv2d`` followed by ``tanh``
        produces the output.

        Args:
            x: Input tensor fed to the first linear layer.
            is_train: Forwarded to every batch-norm layer.  Per-layer logging
                through ``self.logger`` is enabled only when this is exactly
                ``True``; otherwise ``None`` is passed as the logger.

        Returns:
            The tensor produced by the final ``tanh`` layer.

        Side effects:
            Sets ``self.reuse`` to ``True`` for the next call and stores the
            trainable variables collected under ``self.name`` in
            ``self.variables``.
        """
        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            # One logger choice shared by every layer below.
            log = self.logger if is_train is True else None

            def upconv(inp, idx, layer_name):
                # Transposed convolution whose output width is self.dims[idx].
                return tf_utils.deconv2d(inp,
                                         output_dim=self.dims[idx],
                                         name=layer_name,
                                         initializer='He',
                                         logger=log)

            def bn_relu(inp, bn_name, relu_name):
                # Batch norm followed by ReLU.
                normed = tf_utils.norm(inp,
                                       name=bn_name,
                                       _type='batch',
                                       _ops=self._ops,
                                       is_train=is_train,
                                       logger=log)
                return tf_utils.relu(normed, name=relu_name, logger=log)

            # NOTE: the shape annotations below are the original author's and
            # presumably assume a specific self.dims configuration.

            # (N, 100) -> (N, 4, 5, 1024)
            projected = tf_utils.linear(x,
                                        4 * 5 * self.dims[0],
                                        name='h0_linear',
                                        initializer='He',
                                        logger=log)
            batch_size = tf.shape(projected)[0]
            feat = tf.reshape(projected, [batch_size, 4, 5, self.dims[0]])
            feat = bn_relu(feat, 'h0_batch', 'h0_relu')

            # (N, 4, 5, 1024) -> (N, 8, 10, 512)
            feat = upconv(feat, 1, 'h1_deconv2d')
            feat = bn_relu(feat, 'h1_batch', 'h1_relu')

            # (N, 8, 10, 512) -> (N, 16, 20, 256)
            feat = upconv(feat, 2, 'h2_deconv2d')
            feat = bn_relu(feat, 'h2_batch', 'h2_relu')

            # (N, 16, 20, 256) -> (N, 15, 20, 256): drop the last row on axis 1
            kept_rows, _ = tf.split(feat, [15, 1], axis=1, name='h2_split')
            tf_utils.print_activations(kept_rows, logger=log)

            # (N, 15, 20, 256) -> (N, 30, 40, 128)
            feat = upconv(kept_rows, 3, 'h3_deconv2d')
            feat = bn_relu(feat, 'h3_batch', 'h3_relu')

            # (N, 30, 40, 128) -> (N, 60, 80, 64)
            feat = upconv(feat, 4, 'h4_deconv2d')
            feat = bn_relu(feat, 'h4_batch', 'h4_relu')

            # (N, 60, 80, 64) -> (N, 120, 160, 1)
            feat = upconv(feat, 5, 'h5_deconv')
            output = tf_utils.tanh(feat, name='output', logger=log)

            # Subsequent calls must reuse the variables created above.
            self.reuse = True
            self.variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

            return output
示例#26
0
    def u_net(self):
        """Build the U-Net graph on ``self.inp_img`` and store it in ``self.pred``.

        Stages 1-5 each apply two 3x3 stride-1 'VALID' convolutions with ReLU;
        stages 2-5 start with a max-pool, and stages 4 and 5 end with dropout
        using ``self.keep_prob``.  Stages 6-9 each apply a 2x2 transposed
        convolution with ReLU, centre-crop the matching earlier feature map to
        the upsampled height/width, concatenate both along axis 3 and apply
        two more 3x3 convolutions with ReLU.  A final 1x1 'SAME' convolution
        named ``'output'`` is assigned to ``self.pred``.

        Layer widths are read from ``self.conv_dims[0]`` .. ``[22]`` and every
        layer is reported through ``self.logger``.  The crop offsets are
        computed from static shapes (``get_shape().as_list()``).

        Returns:
            None.  The result is exposed through ``self.pred``.
        """
        log = self.logger

        def conv3x3_relu(inp, dim_idx, conv_name, relu_name):
            # 3x3, stride-1, unpadded convolution followed by ReLU.
            conv = tf_utils.conv2d(x=inp,
                                   output_dim=self.conv_dims[dim_idx],
                                   k_h=3,
                                   k_w=3,
                                   d_h=1,
                                   d_w=1,
                                   padding='VALID',
                                   initializer='He',
                                   name=conv_name,
                                   logger=log)
            return tf_utils.relu(conv, name=relu_name, logger=log)

        def pool(inp, pool_name):
            # Max-pool with tf_utils' default settings.
            return tf_utils.max_pool(x=inp, name=pool_name, logger=log)

        def drop(inp, drop_name):
            # Dropout driven by self.keep_prob.
            return tf_utils.dropout(x=inp,
                                    keep_prob=self.keep_prob,
                                    name=drop_name,
                                    logger=log)

        def up2x2_relu(inp, dim_idx, deconv_name, relu_name):
            # 2x2 transposed convolution followed by ReLU.
            up = tf_utils.deconv2d(x=inp,
                                   output_dim=self.conv_dims[dim_idx],
                                   k_h=2,
                                   k_w=2,
                                   initializer='He',
                                   name=deconv_name,
                                   logger=log)
            return tf_utils.relu(up, name=relu_name, logger=log)

        def crop_and_concat(skip, upsampled, concat_name):
            # Centre-crop the skip tensor to the upsampled spatial size, then
            # join the two along the channel axis.
            skip_h, skip_w = skip.get_shape().as_list()[1:3]
            up_h, up_w = upsampled.get_shape().as_list()[1:3]
            cropped = tf.image.crop_to_bounding_box(
                image=skip,
                offset_height=int(0.5 * (skip_h - up_h)),
                offset_width=int(0.5 * (skip_w - up_w)),
                target_height=up_h,
                target_width=up_w)
            tf_utils.print_activations(cropped, logger=log)
            return tf_utils.concat(values=[cropped, upsampled],
                                   axis=3,
                                   name=concat_name,
                                   logger=log)

        # Stage 1
        tf_utils.print_activations(self.inp_img, logger=log)
        enc1 = conv3x3_relu(self.inp_img, 0, 's1_conv1', 'relu_s1_conv1')
        enc1 = conv3x3_relu(enc1, 1, 's1_conv2', 'relu_s1_conv2')

        # Stage 2
        enc2 = pool(enc1, 's2_maxpool')
        enc2 = conv3x3_relu(enc2, 2, 's2_conv1', 'relu_s2_conv1')
        enc2 = conv3x3_relu(enc2, 3, 's2_conv2', 'relu_s2_conv2')

        # Stage 3
        enc3 = pool(enc2, 's3_maxpool')
        enc3 = conv3x3_relu(enc3, 4, 's3_conv1', 'relu_s3_conv1')
        enc3 = conv3x3_relu(enc3, 5, 's3_conv2', 'relu_s3_conv2')

        # Stage 4
        enc4 = pool(enc3, 's4_maxpool')
        enc4 = conv3x3_relu(enc4, 6, 's4_conv1', 'relu_s4_conv1')
        enc4 = conv3x3_relu(enc4, 7, 's4_conv2', 'relu_s4_conv2')
        enc4 = drop(enc4, 's4_conv2_dropout')

        # Stage 5
        bottom = pool(enc4, 's5_maxpool')
        bottom = conv3x3_relu(bottom, 8, 's5_conv1', 'relu_s5_conv1')
        bottom = conv3x3_relu(bottom, 9, 's5_conv2', 'relu_s5_conv2')
        bottom = drop(bottom, 's5_conv2_dropout')

        # Stage 6 (skip connection from stage 4, after dropout)
        dec6 = up2x2_relu(bottom, 10, 's6_deconv1', 'relu_s6_deconv1')
        dec6 = crop_and_concat(enc4, dec6, 's6_concat')
        dec6 = conv3x3_relu(dec6, 11, 's6_conv2', 'relu_s6_conv2')
        dec6 = conv3x3_relu(dec6, 12, 's6_conv3', 'relu_s6_conv3')

        # Stage 7 (skip connection from stage 3)
        dec7 = up2x2_relu(dec6, 13, 's7_deconv1', 'relu_s7_deconv1')
        dec7 = crop_and_concat(enc3, dec7, 's7_concat')
        dec7 = conv3x3_relu(dec7, 14, 's7_conv2', 'relu_s7_conv2')
        dec7 = conv3x3_relu(dec7, 15, 's7_conv3', 'relu_s7_conv3')

        # Stage 8 (skip connection from stage 2)
        dec8 = up2x2_relu(dec7, 16, 's8_deconv1', 'relu_s8_deconv1')
        dec8 = crop_and_concat(enc2, dec8, 's8_concat')
        dec8 = conv3x3_relu(dec8, 17, 's8_conv2', 'relu_s8_conv2')
        # NOTE(review): 'relu_conv3' breaks the 'relu_s8_*' naming pattern of
        # the neighbouring layers; it is kept as-is so the graph's op names do
        # not change.
        dec8 = conv3x3_relu(dec8, 18, 's8_conv3', 'relu_conv3')

        # Stage 9 (skip connection from stage 1)
        dec9 = up2x2_relu(dec8, 19, 's9_deconv1', 'relu_s9_deconv1')
        dec9 = crop_and_concat(enc1, dec9, 's9_concat')
        dec9 = conv3x3_relu(dec9, 20, 's9_conv2', 'relu_s9_conv2')
        dec9 = conv3x3_relu(dec9, 21, 's9_conv3', 'relu_s9_conv3')

        # Final 1x1 projection; no activation is applied here.
        self.pred = tf_utils.conv2d(x=dec9,
                                    output_dim=self.conv_dims[22],
                                    k_h=1,
                                    k_w=1,
                                    d_h=1,
                                    d_w=1,
                                    padding='SAME',
                                    initializer='He',
                                    name='output',
                                    logger=log)
示例#27
0
    def forward_network(self, img, padding='SAME', reuse=False):
        with tf.compat.v1.variable_scope(self.name, reuse=reuse):
            # Stage 0
            s0_conv1 = tf_utils.conv2d(x=img,
                                       output_dim=self.conv_dims[0],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s0_conv1')
            s0_conv1 = tf_utils.relu(s0_conv1, name='relu_s0_conv1')

            s0_conv2 = tf_utils.conv2d(x=s0_conv1,
                                       output_dim=self.conv_dims[0],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s0_conv2')
            s0_conv2 = tf_utils.norm(s0_conv2,
                                     name='s0_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s0_conv2 = tf_utils.relu(s0_conv2, name='relu_s0_conv2')

            # Stage 1
            s1_maxpool = tf_utils.max_pool(x=s0_conv2, name='s1_maxpool2d')

            s1_conv1 = tf_utils.conv2d(x=s1_maxpool,
                                       output_dim=self.conv_dims[0],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s1_conv1')
            s1_conv1 = tf_utils.norm(s1_conv1,
                                     name='s1_norm0',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1')

            s1_conv2 = tf_utils.conv2d(x=s1_conv1,
                                       output_dim=self.conv_dims[1],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s1_conv2')
            s1_conv2 = tf_utils.norm(s1_conv2,
                                     name='s1_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2')

            # Stage 2
            s2_maxpool = tf_utils.max_pool(x=s1_conv2, name='s2_maxpool2d')
            s2_conv1 = tf_utils.conv2d(x=s2_maxpool,
                                       output_dim=self.conv_dims[2],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s2_conv1')
            s2_conv1 = tf_utils.norm(s2_conv1,
                                     name='s2_norm0',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s2_conv1 = tf_utils.relu(s2_conv1, name='relu_s2_conv1')

            s2_conv2 = tf_utils.conv2d(x=s2_conv1,
                                       output_dim=self.conv_dims[3],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s2_conv2')
            s2_conv2 = tf_utils.norm(s2_conv2,
                                     name='s2_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s2_conv2 = tf_utils.relu(s2_conv2, name='relu_s2_conv2')

            # Stage 3
            s3_maxpool = tf_utils.max_pool(x=s2_conv2, name='s3_maxpool2d')
            s3_conv1 = tf_utils.conv2d(x=s3_maxpool,
                                       output_dim=self.conv_dims[4],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s3_conv1')
            s3_conv1 = tf_utils.norm(s3_conv1,
                                     name='s3_norm0',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s3_conv1 = tf_utils.relu(s3_conv1, name='relu_s3_conv1')

            s3_conv2 = tf_utils.conv2d(x=s3_conv1,
                                       output_dim=self.conv_dims[5],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s3_conv2')
            s3_conv2 = tf_utils.norm(s3_conv2,
                                     name='s3_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s3_conv2 = tf_utils.relu(s3_conv2, name='relu_s3_conv2')

            # Stage 4
            s4_maxpool = tf_utils.max_pool(x=s3_conv2, name='s4_maxpool2d')
            s4_conv1 = tf_utils.conv2d(x=s4_maxpool,
                                       output_dim=self.conv_dims[6],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s4_conv1')
            s4_conv1 = tf_utils.norm(s4_conv1,
                                     name='s4_norm0',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s4_conv1 = tf_utils.relu(s4_conv1, name='relu_s4_conv1')

            s4_conv2 = tf_utils.conv2d(x=s4_conv1,
                                       output_dim=self.conv_dims[7],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s4_conv2')
            s4_conv2 = tf_utils.norm(s4_conv2,
                                     name='s4_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s4_conv2 = tf_utils.relu(s4_conv2, name='relu_s4_conv2')
            s4_conv2_drop = tf_utils.dropout(x=s4_conv2,
                                             keep_prob=0.,
                                             name='s4_dropout')

            # Stage 5
            s5_maxpool = tf_utils.max_pool(x=s4_conv2_drop,
                                           name='s5_maxpool2d')
            s5_conv1 = tf_utils.conv2d(x=s5_maxpool,
                                       output_dim=self.conv_dims[8],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s5_conv1')
            s5_conv1 = tf_utils.norm(s5_conv1,
                                     name='s5_norm0',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s5_conv1 = tf_utils.relu(s5_conv1, name='relu_s5_conv1')

            s5_conv2 = tf_utils.conv2d(x=s5_conv1,
                                       output_dim=self.conv_dims[9],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s5_conv2')
            s5_conv2 = tf_utils.norm(s5_conv2,
                                     name='s5_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s5_conv2 = tf_utils.relu(s5_conv2, name='relu_s5_conv2')
            s5_conv2_drop = tf_utils.dropout(x=s5_conv2,
                                             keep_prob=0.,
                                             name='s5_dropout')

            # Stage 6
            s6_deconv1 = tf_utils.deconv2d(x=s5_conv2_drop,
                                           output_dim=self.conv_dims[10],
                                           k_h=2,
                                           k_w=2,
                                           initializer='He',
                                           name='s6_deconv1')
            s6_deconv1 = tf_utils.norm(s6_deconv1,
                                       name='s6_norm0',
                                       _type='batch',
                                       _ops=self._ops,
                                       is_train=False)
            s6_deconv1 = tf_utils.relu(s6_deconv1, name='relu_s6_deconv1')
            # Cropping
            w1 = s4_conv2_drop.get_shape().as_list()[2]
            w2 = s6_deconv1.get_shape().as_list()[2] - s4_conv2_drop.get_shape(
            ).as_list()[2]
            s6_deconv1_split, _ = tf.split(s6_deconv1,
                                           num_or_size_splits=[w1, w2],
                                           axis=2,
                                           name='axis2_split')
            tf_utils.print_activations(s6_deconv1_split)
            # Concat
            s6_concat = tf_utils.concat(
                values=[s6_deconv1_split, s4_conv2_drop],
                axis=3,
                name='s6_axis3_concat')

            s6_conv2 = tf_utils.conv2d(x=s6_concat,
                                       output_dim=self.conv_dims[11],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s6_conv2')
            s6_conv2 = tf_utils.norm(s6_conv2,
                                     name='s6_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s6_conv2 = tf_utils.relu(s6_conv2, name='relu_s6_conv2')

            s6_conv3 = tf_utils.conv2d(x=s6_conv2,
                                       output_dim=self.conv_dims[12],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s6_conv3')
            s6_conv3 = tf_utils.norm(s6_conv3,
                                     name='s6_norm2',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s6_conv3 = tf_utils.relu(s6_conv3, name='relu_s6_conv3')

            # Stage 7
            s7_deconv1 = tf_utils.deconv2d(x=s6_conv3,
                                           output_dim=self.conv_dims[13],
                                           k_h=2,
                                           k_w=2,
                                           initializer='He',
                                           name='s7_deconv1')
            s7_deconv1 = tf_utils.norm(s7_deconv1,
                                       name='s7_norm0',
                                       _type='batch',
                                       _ops=self._ops,
                                       is_train=False)
            s7_deconv1 = tf_utils.relu(s7_deconv1, name='relu_s7_deconv1')
            # Concat
            s7_concat = tf_utils.concat(values=[s7_deconv1, s3_conv2],
                                        axis=3,
                                        name='s7_axis3_concat')

            s7_conv2 = tf_utils.conv2d(x=s7_concat,
                                       output_dim=self.conv_dims[14],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s7_conv2')
            s7_conv2 = tf_utils.norm(s7_conv2,
                                     name='s7_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s7_conv2 = tf_utils.relu(s7_conv2, name='relu_s7_conv2')

            s7_conv3 = tf_utils.conv2d(x=s7_conv2,
                                       output_dim=self.conv_dims[15],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s7_conv3')
            s7_conv3 = tf_utils.norm(s7_conv3,
                                     name='s7_norm2',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s7_conv3 = tf_utils.relu(s7_conv3, name='relu_s7_conv3')

            # Stage 8
            s8_deconv1 = tf_utils.deconv2d(x=s7_conv3,
                                           output_dim=self.conv_dims[16],
                                           k_h=2,
                                           k_w=2,
                                           initializer='He',
                                           name='s8_deconv1')
            s8_deconv1 = tf_utils.norm(s8_deconv1,
                                       name='s8_norm0',
                                       _type='batch',
                                       _ops=self._ops,
                                       is_train=False)
            s8_deconv1 = tf_utils.relu(s8_deconv1, name='relu_s8_deconv1')
            # Concat
            s8_concat = tf_utils.concat(values=[s8_deconv1, s2_conv2],
                                        axis=3,
                                        name='s8_axis3_concat')

            s8_conv2 = tf_utils.conv2d(x=s8_concat,
                                       output_dim=self.conv_dims[17],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s8_conv2')
            s8_conv2 = tf_utils.norm(s8_conv2,
                                     name='s8_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s8_conv2 = tf_utils.relu(s8_conv2, name='relu_s8_conv2')

            s8_conv3 = tf_utils.conv2d(x=s8_conv2,
                                       output_dim=self.conv_dims[18],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s8_conv3')
            s8_conv3 = tf_utils.norm(s8_conv3,
                                     name='s8_norm2',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s8_conv3 = tf_utils.relu(s8_conv3, name='relu_conv3')

            # Stage 9
            s9_deconv1 = tf_utils.deconv2d(x=s8_conv3,
                                           output_dim=self.conv_dims[19],
                                           k_h=2,
                                           k_w=2,
                                           initializer='He',
                                           name='s9_deconv1')
            s9_deconv1 = tf_utils.norm(s9_deconv1,
                                       name='s9_norm0',
                                       _type='batch',
                                       _ops=self._ops,
                                       is_train=False)
            s9_deconv1 = tf_utils.relu(s9_deconv1, name='relu_s9_deconv1')
            # Concat
            s9_concat = tf_utils.concat(values=[s9_deconv1, s1_conv2],
                                        axis=3,
                                        name='s9_axis3_concat')

            s9_conv2 = tf_utils.conv2d(x=s9_concat,
                                       output_dim=self.conv_dims[20],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s9_conv2')
            s9_conv2 = tf_utils.norm(s9_conv2,
                                     name='s9_norm1',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s9_conv2 = tf_utils.relu(s9_conv2, name='relu_s9_conv2')

            s9_conv3 = tf_utils.conv2d(x=s9_conv2,
                                       output_dim=self.conv_dims[21],
                                       k_h=3,
                                       k_w=3,
                                       d_h=1,
                                       d_w=1,
                                       padding=padding,
                                       initializer='He',
                                       name='s9_conv3')
            s9_conv3 = tf_utils.norm(s9_conv3,
                                     name='s9_norm2',
                                     _type='batch',
                                     _ops=self._ops,
                                     is_train=False)
            s9_conv3 = tf_utils.relu(s9_conv3, name='relu_s9_conv3')

            s10_deconv1 = tf_utils.deconv2d(x=s9_conv3,
                                            output_dim=self.conv_dims[-1],
                                            k_h=2,
                                            k_w=2,
                                            initializer='He',
                                            name='s10_deconv1')
            s10_deconv1 = tf_utils.norm(s10_deconv1,
                                        name='s10_norm0',
                                        _type='batch',
                                        _ops=self._ops,
                                        is_train=False)
            s10_deconv1 = tf_utils.relu(s10_deconv1, name='relu_s10_deconv1')
            # Concat
            s10_concat = tf_utils.concat(values=[s10_deconv1, s0_conv2],
                                         axis=3,
                                         name='s10_axis3_concat')

            s10_conv2 = tf_utils.conv2d(s10_concat,
                                        output_dim=self.conv_dims[-1],
                                        k_h=3,
                                        k_w=3,
                                        d_h=1,
                                        d_w=1,
                                        padding=padding,
                                        initializer='He',
                                        name='s10_conv2')
            s10_conv2 = tf_utils.norm(s10_conv2,
                                      name='s10_norm1',
                                      _type='batch',
                                      _ops=self._ops,
                                      is_train=False)
            s10_conv2 = tf_utils.relu(s10_conv2, name='relu_s10_conv2')

            s10_conv3 = tf_utils.conv2d(x=s10_conv2,
                                        output_dim=self.conv_dims[-1],
                                        k_h=3,
                                        k_w=3,
                                        d_h=1,
                                        d_w=1,
                                        padding=padding,
                                        initializer='He',
                                        name='s10_conv3')
            s10_conv3 = tf_utils.norm(s10_conv3,
                                      name='s10_norm2',
                                      _type='batch',
                                      _ops=self._ops,
                                      is_train=False)
            s10_conv3 = tf_utils.relu(s10_conv3, name='relu_s10_conv3')

            output = tf_utils.conv2d(s10_conv3,
                                     output_dim=self.num_classes,
                                     k_h=1,
                                     k_w=1,
                                     d_h=1,
                                     d_w=1,
                                     padding=padding,
                                     initializer='He',
                                     name='output')

            return output
    def __call__(self, x):
        """Run ``x`` through the five-layer convolutional stack.

        Every op is created inside ``variable_scope(self.name)`` using the
        current value of ``self.reuse``. As a side effect the call sets
        ``self.reuse`` to True (so later calls share variables) and stores
        the trainable variables found under ``self.name`` in
        ``self.variables``.

        Args:
            x: input tensor handed to the first conv2d. Presumably an image
                batch; the spatial sizes quoted in the comments below are the
                original author's notes and depend on ``tf_utils.conv2d``
                defaults that are not visible here.

        Returns:
            The tensor produced by the final ``output_conv2d`` layer.
        """
        log = self.logger
        with tf.compat.v1.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x, logger=log)

            # Layer 0: conv -> leaky ReLU, no normalisation.
            # Author's shape note: (320, 200) -> (160, 100)
            net = tf_utils.conv2d(x,
                                  output_dim=self.dis_c[0],
                                  initializer='He',
                                  logger=log,
                                  name='h0_conv2d')
            net = tf_utils.lrelu(net, logger=log, name='h0_lrelu')

            # Layers 1-3 share one pattern: conv -> norm -> leaky ReLU.
            # Author's shape notes: (160, 100) -> (80, 50) -> (40, 25)
            # -> (20, 13)
            for layer_idx in (1, 2, 3):
                net = tf_utils.conv2d(net,
                                      output_dim=self.dis_c[layer_idx],
                                      initializer='He',
                                      logger=log,
                                      name='h{}_conv2d'.format(layer_idx))
                net = tf_utils.norm(net,
                                    _type=self.norm,
                                    _ops=self._ops,
                                    logger=log,
                                    name='h{}_norm'.format(layer_idx))
                net = tf_utils.lrelu(net,
                                     logger=log,
                                     name='h{}_lrelu'.format(layer_idx))

            # Final 3x3 convolution with stride 1; no activation follows.
            # Author's shape note: (20, 13) -> (20, 13)
            output = tf_utils.conv2d(net,
                                     output_dim=self.dis_c[4],
                                     k_h=3,
                                     k_w=3,
                                     d_h=1,
                                     d_w=1,
                                     initializer='He',
                                     logger=log,
                                     name='output_conv2d')

            # Subsequent calls must reuse the variables created above.
            self.reuse = True
            self.variables = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
                scope=self.name)

        return output
    def forward_network(self, input_img, padding='SAME', reuse=False):
        """Build the classification network and return its logits.

        Pipeline, all under ``variable_scope(self.name, reuse=reuse)``:
        7x7 stride-2 ``conv1`` -> 3x3 stride-2 max-pool -> four
        ``block_layer`` stages built from ``self.bottleneck_block`` ->
        batch norm -> ReLU -> average pool over dims 1 and 2 of the static
        shape -> flatten -> linear layer with ``self.num_classes`` outputs.

        Args:
            input_img: input tensor fed to ``conv1``.
            padding: accepted for interface compatibility; this method does
                not read it.
            reuse: forwarded to ``variable_scope`` to control variable
                sharing.

        Returns:
            The tensor returned by the final ``logits`` linear layer.
        """
        # (filters, strides) for block_layer1 .. block_layer4, in order.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))

        with tf.compat.v1.variable_scope(self.name, reuse=reuse):
            tf_utils.print_activations(input_img, logger=self.logger)

            # Stem: strided convolution followed by max pooling.
            net = self.conv2d_fixed_padding(inputs=input_img,
                                            filters=64,
                                            kernel_size=7,
                                            strides=2,
                                            name='conv1')
            net = tf_utils.max_pool(net,
                                    name='3x3_maxpool',
                                    ksize=[1, 3, 3, 1],
                                    strides=[1, 2, 2, 1],
                                    logger=self.logger)

            # Residual stages; the block count per stage is self.layers[i].
            for stage_idx, (n_filters, stride) in enumerate(stage_specs):
                net = self.block_layer(
                    inputs=net,
                    filters=n_filters,
                    block_fn=self.bottleneck_block,
                    blocks=self.layers[stage_idx],
                    strides=stride,
                    train_mode=self.train_mode,
                    name='block_layer{}'.format(stage_idx + 1))

            # Norm + ReLU applied once more before pooling.
            net = tf_utils.norm(net,
                                name='before_gap_batch_norm',
                                _type='batch',
                                _ops=self._ops,
                                is_train=self.train_mode,
                                logger=self.logger)
            net = tf_utils.relu(net,
                                name='before_gap_relu',
                                logger=self.logger)

            # The pooling window is sized from the static shape, so dims 1
            # and 2 must be known at graph-construction time.
            _, height, width, _ = net.get_shape().as_list()
            net = tf_utils.avg_pool(net,
                                    name='gap',
                                    ksize=[1, height, width, 1],
                                    strides=[1, 1, 1, 1],
                                    logger=self.logger)

            net = tf_utils.flatten(net, name='flatten', logger=self.logger)
            logits = tf_utils.linear(net, self.num_classes, name='logits')

            return logits
示例#30
0
    def __call__(self, x):
        """Build the conv/deconv network on ``x`` and return its output.

        Structure, all under ``variable_scope(self.name)``:
          1. ``conv0``: 4x4 stride-2 conv + leaky ReLU.
          2. One 4x4 stride-2 conv + norm + leaky ReLU per entry of
             ``self.conv_dims[1:]``.
          3. One 4x4 deconv + norm + ReLU per entry of ``self.deconv_dims``.
          4. Two ``tf.split`` crops that keep the first
             ``int(self.img_size[k] / 2)`` entries along axis 1 (k=0) and
             axis 2 (k=1), discarding the remainder.
          5. A final 4x4 deconv to ``self.img_size[2]`` channels, then tanh.

        Side effects: sets ``self.reuse`` to True and stores the trainable
        variables found under ``self.name`` in ``self.variables``.

        Args:
            x: input tensor. Presumably (N, H, W, C), going by the original
                author's shape comments; not verifiable from this method.

        Returns:
            The tensor returned by the final ``conv4_tanh`` op.
        """
        with tf.variable_scope(self.name, reuse=self.reuse):
            tf_utils.print_activations(x)

            # First layer has no normalisation.
            net = tf_utils.conv2d(x,
                                  self.conv_dims[0],
                                  k_h=4,
                                  k_w=4,
                                  d_h=2,
                                  d_w=2,
                                  padding='SAME',
                                  name='conv0_conv2d')
            net = tf_utils.lrelu(net, name='conv0_lrelu', is_print=True)

            # Remaining stride-2 convolutions, numbered from 1.
            for layer_no, n_filters in enumerate(self.conv_dims[1:],
                                                 start=1):
                net = tf_utils.conv2d(net,
                                      n_filters,
                                      k_h=4,
                                      k_w=4,
                                      d_h=2,
                                      d_w=2,
                                      padding='SAME',
                                      name='conv{}_conv2d'.format(layer_no))
                net = tf_utils.norm(net,
                                    _type=self.norm,
                                    _ops=self._ops,
                                    name='conv{}_norm'.format(layer_no))
                net = tf_utils.lrelu(net,
                                     name='conv{}_lrelu'.format(layer_no),
                                     is_print=True)

            # Deconvolution layers, numbered from 0.
            for layer_no, n_filters in enumerate(self.deconv_dims):
                net = tf_utils.deconv2d(
                    net,
                    n_filters,
                    k_h=4,
                    k_w=4,
                    name='deconv{}_conv2d'.format(layer_no))
                net = tf_utils.norm(net,
                                    _type=self.norm,
                                    _ops=self._ops,
                                    name='deconv{}_norm'.format(layer_no))
                net = tf_utils.relu(net,
                                    name='deconv{}_relu'.format(layer_no),
                                    is_print=True)

            # Crop axis 1 and then axis 2 down to half of the matching
            # img_size entry by splitting and dropping the surplus part.
            # Author's example: (N, 152, 104, 64) -> (N, 150, 104, 64)
            # -> (N, 150, 100, 64).
            for size_idx, axis in enumerate((1, 2)):
                keep = int(self.img_size[size_idx] / 2)
                surplus = net.get_shape().as_list()[axis] - keep
                net, _ = tf.split(net, [keep, surplus],
                                  axis=axis,
                                  name='split_{}'.format(size_idx))
                tf_utils.print_activations(net)

            # Final deconv to img_size[2] output channels, then tanh.
            net = tf_utils.deconv2d(net,
                                    self.img_size[2],
                                    k_h=4,
                                    k_w=4,
                                    name='conv3_deconv2d')
            output = tf_utils.tanh(net, name='conv4_tanh', is_print=True)

            # Subsequent calls must reuse the variables created above.
            self.reuse = True
            self.variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
            return output