def decoder(self, z, name='decoder', is_reuse=False):
    with tf.variable_scope(name) as scope:
        if is_reuse is True:
            scope.reuse_variables()
        tf_utils.print_activations(z)

        # 1st hidden layer
        h0_linear = tf_utils.linear(z, self.n_hidden, name='h0_linear')
        h0_tanh = tf_utils.tanh(h0_linear, name='h0_tanh')
        h0_drop = tf.nn.dropout(h0_tanh, keep_prob=self.keep_prob_tfph, name='h0_drop')
        tf_utils.print_activations(h0_drop)

        # 2nd hidden layer
        h1_linear = tf_utils.linear(h0_drop, self.n_hidden, name='h1_linear')
        h1_elu = tf_utils.elu(h1_linear, name='h1_elu')
        h1_drop = tf.nn.dropout(h1_elu, keep_prob=self.keep_prob_tfph, name='h1_drop')
        tf_utils.print_activations(h1_drop)

        # 3rd hidden layer
        h2_linear = tf_utils.linear(h1_drop, self.output_dim, name='h2_linear')
        h2_sigmoid = tf_utils.sigmoid(h2_linear, name='h2_sigmoid')
        tf_utils.print_activations(h2_sigmoid)

        output = tf.reshape(h2_sigmoid, [-1, *self.image_size])
        tf_utils.print_activations(output)

        return output
def bottleneck_block(self, inputs, filters, train_mode, projection_shortcut, strides, name):
    with tf.compat.v1.variable_scope(name):
        shortcut = inputs
        inputs = tf_utils.relu(inputs, name='relu_0', logger=None)

        # The projection shortcut should come after the first batch norm and ReLU
        # since it performs a 1x1 convolution.
        if projection_shortcut is not None:
            shortcut = self.projection_shortcut(inputs=inputs, filters_out=filters,
                                                strides=strides, name='conv_projection')

        inputs = self.conv2d_fixed_padding(inputs=inputs, filters=filters, kernel_size=3,
                                           strides=strides, name='conv_0')
        inputs = tf_utils.relu(inputs, name='relu_1', logger=None)
        inputs = self.conv2d_fixed_padding(inputs=inputs, filters=filters, kernel_size=3,
                                           strides=1, name='conv_1')

        output = tf.identity(inputs + shortcut, name=(name + '_output'))
        tf_utils.print_activations(output, logger=None)

        return output
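# Note: `projection_shortcut` above is expected to bring the identity path to the same
# spatial size and channel count as the residual branch before the addition. A minimal
# sketch of such a method, assuming the same `conv2d_fixed_padding` helper, is a strided
# 1x1 convolution (illustrative only; the actual implementation may differ):
def projection_shortcut(self, inputs, filters_out, strides, name):
    # 1x1 convolution that downsamples (via `strides`) and widens the shortcut so that
    # it can be added element-wise to the block output.
    return self.conv2d_fixed_padding(inputs=inputs, filters=filters_out, kernel_size=1,
                                     strides=strides, name=name)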
def model_g(self, x):
    with tf.variable_scope(self.name, reuse=self.reuse):
        tf_utils.print_activations(x)

        # (N, H, W, C) -> (N, H/2, W/2, 64)
        conv1 = tf_utils.conv2d(x, self.ndf, k_h=4, k_w=4, d_h=2, d_w=2, padding='SAME',
                                name='conv1_conv')
        conv1 = tf_utils.lrelu(conv1, name='conv1_lrelu', is_print=True)

        # (N, H/2, W/2, 64) -> (N, H/4, W/4, 128)
        conv2 = tf_utils.conv_norm_lrelu(conv1, 2 * self.ndf, k_h=4, k_w=4, d_h=2, d_w=2,
                                         padding='SAME', name='conv2_conv', ops=self._ops)

        # (N, H/4, W/4, 128) -> (N, H/8, W/8, 256)
        conv3 = tf_utils.conv_norm_lrelu(conv2, 4 * self.ndf, k_h=4, k_w=4, d_h=2, d_w=2,
                                         padding='SAME', name='conv3_conv', ops=self._ops)

        # (N, H/8, W/8, 256) -> (N, H/16, W/16, 512)
        conv4 = tf_utils.conv2d(conv3, 8 * self.ndf, k_h=4, k_w=4, d_h=2, d_w=2,
                                padding='SAME', name='conv4_conv', ops=self._ops)

        # (N, H/16, W/16, 512) -> (N, H/16, W/16, 1)
        conv5 = tf_utils.conv2d(conv4, 1, k_h=4, k_w=4, d_h=1, d_w=1, padding='SAME',
                                name='conv5_conv', is_print=True)

        output = tf.identity(conv5, name='output_without_sigmoid')

    # set reuse=True for next call
    self.reuse = True
    self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

    return output
def forward_network(self, input_img, reuse=False):
    with tf.compat.v1.variable_scope(self.name, reuse=reuse):
        tf_utils.print_activations(input_img, logger=None)
        inputs = self.conv2d_fixed_padding(inputs=input_img, filters=64, kernel_size=7,
                                           strides=2, name='conv1')
        inputs = tf_utils.max_pool(inputs, name='3x3_maxpool', ksize=[1, 3, 3, 1],
                                   strides=[1, 2, 2, 1], logger=None)

        inputs = self.block_layer(inputs=inputs, filters=64, block_fn=self.bottleneck_block,
                                  blocks=self.layers[0], strides=1, train_mode=False,
                                  name='block_layer1')
        inputs = self.block_layer(inputs=inputs, filters=128, block_fn=self.bottleneck_block,
                                  blocks=self.layers[1], strides=2, train_mode=False,
                                  name='block_layer2')
        inputs = self.block_layer(inputs=inputs, filters=256, block_fn=self.bottleneck_block,
                                  blocks=self.layers[2], strides=2, train_mode=False,
                                  name='block_layer3')
        inputs = self.block_layer(inputs=inputs, filters=512, block_fn=self.bottleneck_block,
                                  blocks=self.layers[3], strides=2, train_mode=False,
                                  name='block_layer4')

        inputs = tf_utils.relu(inputs, name='before_flatten_relu', logger=None)
        # _, h, w, _ = inputs.get_shape().as_list()
        # inputs = tf_utils.avg_pool(inputs, name='gap', ksize=[1, h, w, 1],
        #                            strides=[1, 1, 1, 1], logger=self.logger)

        # Flatten & FC1
        inputs = tf_utils.flatten(inputs, name='flatten', logger=None)
        inputs = tf_utils.linear(inputs, 512, name='FC1')
        inputs = tf_utils.relu(inputs, name='FC1_relu', logger=None)

        inputs = tf_utils.linear(inputs, 256, name='FC2')
        inputs = tf_utils.relu(inputs, name='FC2_relu', logger=None)

        logits = tf_utils.linear(inputs, self.num_attribute, name='Out')

        return logits
def basicDiscriminator(self, data, name='d_', is_reuse=False):
    with tf.variable_scope(name) as scope:
        if is_reuse is True:
            scope.reuse_variables()
        tf_utils.print_activations(data)

        # from (N, 32, 32, 1) to (N, 16, 16, 64)
        h0_conv = tf_utils.conv2d(data, self.dis_c[0], k_h=5, k_w=5, name='h0_conv2d')
        h0_lrelu = tf_utils.lrelu(h0_conv, name='h0_lrelu')

        # from (N, 16, 16, 64) to (N, 8, 8, 128)
        h1_conv = tf_utils.conv2d(h0_lrelu, self.dis_c[1], k_h=5, k_w=5, name='h1_conv2d')
        h1_lrelu = tf_utils.lrelu(h1_conv, name='h1_lrelu')

        # from (N, 8, 8, 128) to (N, 4, 4, 256)
        h2_conv = tf_utils.conv2d(h1_lrelu, self.dis_c[2], k_h=5, k_w=5, name='h2_conv2d')
        h2_lrelu = tf_utils.lrelu(h2_conv, name='h2_lrelu')

        # from (N, 4, 4, 256) to (N, 4096) and to (N, 1)
        h2_flatten = flatten(h2_lrelu)
        h3_linear = tf_utils.linear(h2_flatten, 1, name='h3_linear')

        return tf.nn.sigmoid(h3_linear), h3_linear
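# Note: the discriminator returns both the sigmoid probability and the raw logits. A common
# way to use the logits (a sketch, not taken from the original code) is the standard GAN
# loss via `tf.nn.sigmoid_cross_entropy_with_logits`; `real_images` and `fake_images` are
# hypothetical tensors assumed to be defined elsewhere:
_, d_logit_real = self.basicDiscriminator(real_images, name='d_')
_, d_logit_fake = self.basicDiscriminator(fake_images, name='d_', is_reuse=True)

# Discriminator: push real samples toward 1 and fake samples toward 0
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logit_real, labels=tf.ones_like(d_logit_real)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logit_fake, labels=tf.zeros_like(d_logit_fake)))
d_loss = d_loss_real + d_loss_fake

# Generator: fool the discriminator into predicting 1 for fake samples
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logit_fake, labels=tf.ones_like(d_logit_fake)))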
def basicGenerator(self, data, name='g_'):
    with tf.variable_scope(name):
        data_flatten = flatten(data)
        tf_utils.print_activations(data_flatten)

        # from (N, 128) to (N, 4, 4, 256)
        h0_linear = tf_utils.linear(data_flatten, self.gen_c[0], name='h0_linear')
        if self.flags.dataset == 'cifar10':
            h0_linear = tf.reshape(h0_linear,
                                   [tf.shape(h0_linear)[0], 4, 4, int(self.gen_c[0] / (4 * 4))])
            h0_linear = tf_utils.norm(h0_linear, _type='batch', _ops=self.gen_train_ops,
                                      name='h0_norm')
        h0_relu = tf.nn.relu(h0_linear, name='h0_relu')
        h0_reshape = tf.reshape(h0_relu,
                                [tf.shape(h0_relu)[0], 4, 4, int(self.gen_c[0] / (4 * 4))])

        # from (N, 4, 4, 256) to (N, 8, 8, 128)
        h1_deconv = tf_utils.deconv2d(h0_reshape, self.gen_c[1], k_h=5, k_w=5, name='h1_deconv2d')
        if self.flags.dataset == 'cifar10':
            h1_deconv = tf_utils.norm(h1_deconv, _type='batch', _ops=self.gen_train_ops,
                                      name='h1_norm')
        h1_relu = tf.nn.relu(h1_deconv, name='h1_relu')

        # from (N, 8, 8, 128) to (N, 16, 16, 64)
        h2_deconv = tf_utils.deconv2d(h1_relu, self.gen_c[2], k_h=5, k_w=5, name='h2_deconv2d')
        if self.flags.dataset == 'cifar10':
            h2_deconv = tf_utils.norm(h2_deconv, _type='batch', _ops=self.gen_train_ops,
                                      name='h2_norm')
        h2_relu = tf.nn.relu(h2_deconv, name='h2_relu')

        # from (N, 16, 16, 64) to (N, 32, 32, 1)
        output = tf_utils.deconv2d(h2_relu, self.image_size[2], k_h=5, k_w=5, name='h3_deconv2d')

        return tf_utils.tanh(output)
def __call__(self, x, mode=1): with tf.variable_scope(self.name, reuse=self.reuse): x = tf.concat([x, x, x], axis=-1, name='concat') tf_utils.print_activations(x) # conv1 relu1_1 = self.conv_layer(x, 'conv1_1', trainable=False) relu1_2 = self.conv_layer(relu1_1, 'conv1_2', trainable=False) pool_1 = tf_utils.max_pool_2x2(relu1_2, name='max_pool_1') tf_utils.print_activations(pool_1) # conv2 relu2_1 = self.conv_layer(pool_1, 'conv2_1', trainable=False) relu2_2 = self.conv_layer(relu2_1, 'conv2_2', trainable=False) pool_2 = tf_utils.max_pool_2x2(relu2_2, name='max_pool_2') tf_utils.print_activations(pool_2) # conv3 relu3_1 = self.conv_layer(pool_2, 'conv3_1', trainable=False) relu3_2 = self.conv_layer(relu3_1, 'conv3_2', trainable=False) relu3_3 = self.conv_layer(relu3_2, 'conv3_3', trainable=False) pool_3 = tf_utils.max_pool_2x2(relu3_3, name='max_pool_3') tf_utils.print_activations(pool_3) # conv4 relu4_1 = self.conv_layer(pool_3, 'conv4_1', trainable=False) relu4_2 = self.conv_layer(relu4_1, 'conv4_2', trainable=False) relu4_3 = self.conv_layer(relu4_2, 'conv4_3', trainable=False) pool_4 = tf_utils.max_pool_2x2(relu4_3, name='max_pool_4') tf_utils.print_activations(pool_4) # conv5 relu5_1 = self.conv_layer(pool_4, 'conv5_1', trainable=False) relu5_2 = self.conv_layer(relu5_1, 'conv5_2', trainable=False) relu5_3 = self.conv_layer(relu5_2, 'conv5_3', trainable=False) # set reuse=True for next call self.reuse = True if mode == 1: outputs = [relu1_2] elif mode == 2: outputs = [relu1_2, relu2_2] elif mode == 3: outputs = [relu1_2, relu2_2, relu3_3] elif mode == 4: outputs = [relu1_2, relu2_2, relu3_3, relu4_3] elif mode == 5: outputs = [relu1_2, relu2_2, relu3_3, relu4_3, relu5_3] else: raise NotImplementedError return outputs
def __call__(self, x): with tf.variable_scope(self.name, reuse=self.reuse): tf_utils.print_activations(x) # (N, H, W, C) -> (N, H, W, 64) conv1 = tf_utils.padding2d(x, p_h=3, p_w=3, pad_type='REFLECT', name='conv1_padding') conv1 = tf_utils.conv2d(conv1, self.ngf, k_h=7, k_w=7, d_h=1, d_w=1, padding='VALID', name='conv1_conv') conv1 = tf_utils.norm(conv1, _type='instance', _ops=self._ops, name='conv1_norm') conv1 = tf_utils.relu(conv1, name='conv1_relu', is_print=True) # (N, H, W, 64) -> (N, H/2, W/2, 128) conv2 = tf_utils.conv2d(conv1, 2*self.ngf, k_h=3, k_w=3, d_h=2, d_w=2, padding='SAME', name='conv2_conv') conv2 = tf_utils.norm(conv2, _type='instance', _ops=self._ops, name='conv2_norm',) conv2 = tf_utils.relu(conv2, name='conv2_relu', is_print=True) # (N, H/2, W/2, 128) -> (N, H/4, W/4, 256) conv3 = tf_utils.conv2d(conv2, 4*self.ngf, k_h=3, k_w=3, d_h=2, d_w=2, padding='SAME', name='conv3_conv') conv3 = tf_utils.norm(conv3, _type='instance', _ops=self._ops, name='conv3_norm',) conv3 = tf_utils.relu(conv3, name='conv3_relu', is_print=True) # (N, H/4, W/4, 256) -> (N, H/4, W/4, 256) if (self.image_size[0] <= 128) and (self.image_size[1] <= 128): # use 6 residual blocks for 128x128 images res_out = tf_utils.n_res_blocks(conv3, num_blocks=6, is_print=True) else: # use 9 blocks for higher resolution res_out = tf_utils.n_res_blocks(conv3, num_blocks=9, is_print=True) # (N, H/4, W/4, 256) -> (N, H/2, W/2, 128) conv4 = tf_utils.deconv2d(res_out, 2*self.ngf, name='conv4_deconv2d') conv4 = tf_utils.norm(conv4, _type='instance', _ops=self._ops, name='conv4_norm') conv4 = tf_utils.relu(conv4, name='conv4_relu', is_print=True) # (N, H/2, W/2, 128) -> (N, H, W, 64) conv5 = tf_utils.deconv2d(conv4, self.ngf, name='conv5_deconv2d') conv5 = tf_utils.norm(conv5, _type='instance', _ops=self._ops, name='conv5_norm') conv5 = tf_utils.relu(conv5, name='conv5_relu', is_print=True) # (N, H, W, 64) -> (N, H, W, 3) conv6 = tf_utils.padding2d(conv5, p_h=3, p_w=3, pad_type='REFLECT', name='output_padding') conv6 = tf_utils.conv2d(conv6, self.image_size[2], k_h=7, k_w=7, d_h=1, d_w=1, padding='VALID', name='output_conv') output = tf_utils.tanh(conv6, name='output_tanh', is_print=True) # set reuse=True for next call self.reuse = True self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) return output
def resnetDiscriminator(self, data, name='d_', is_reuse=False):
    with tf.variable_scope(name) as scope:
        if is_reuse is True:
            scope.reuse_variables()
        tf_utils.print_activations(data)

        # (N, 64, 64, 64)
        conv_0 = tf_utils.conv2d(data, output_dim=self.dis_c[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                 name='conv_0')

        # (N, 32, 32, 128)
        resblock_1 = tf_utils.res_block_v2(conv_0, self.dis_c[1], filter_size=3,
                                           _ops=self.dis_train_ops, norm_='layer',
                                           resample='down', name='res_block_1')

        # (N, 16, 16, 256)
        resblock_2 = tf_utils.res_block_v2(resblock_1, self.dis_c[2], filter_size=3,
                                           _ops=self.dis_train_ops, norm_='layer',
                                           resample='down', name='res_block_2')

        # (N, 8, 8, 512)
        resblock_3 = tf_utils.res_block_v2(resblock_2, self.dis_c[3], filter_size=3,
                                           _ops=self.dis_train_ops, norm_='layer',
                                           resample='down', name='res_block_3')

        # (N, 4, 4, 512)
        resblock_4 = tf_utils.res_block_v2(resblock_3, self.dis_c[4], filter_size=3,
                                           _ops=self.dis_train_ops, norm_='layer',
                                           resample='down', name='res_block_4')

        # (N, 4*4*512)
        flatten_5 = flatten(resblock_4)
        output = tf_utils.linear(flatten_5, 1, name='output')

        return tf.nn.sigmoid(output), output
def __call__(self, x):
    with tf.variable_scope(self.name, reuse=self.reuse):
        tf_utils.print_activations(x)

        # 200 -> 100
        h0_conv2d = tf_utils.conv2d(x, self.dis_c[0], name='h0_conv2d')
        h0_lrelu = tf_utils.lrelu(h0_conv2d, name='h0_lrelu')

        # 100 -> 50
        h1_conv2d = tf_utils.conv2d(h0_lrelu, self.dis_c[1], name='h1_conv2d')
        h1_batchnorm = tf_utils.batch_norm(h1_conv2d, name='h1_batchnorm', _ops=self._ops)
        h1_lrelu = tf_utils.lrelu(h1_batchnorm, name='h1_lrelu')

        # 50 -> 25
        h2_conv2d = tf_utils.conv2d(h1_lrelu, self.dis_c[2], name='h2_conv2d')
        h2_batchnorm = tf_utils.batch_norm(h2_conv2d, name='h2_batchnorm', _ops=self._ops)
        h2_lrelu = tf_utils.lrelu(h2_batchnorm, name='h2_lrelu')

        # 25 -> 13
        h3_conv2d = tf_utils.conv2d(h2_lrelu, self.dis_c[3], name='h3_conv2d')
        h3_batchnorm = tf_utils.batch_norm(h3_conv2d, name='h3_batchnorm', _ops=self._ops)
        h3_lrelu = tf_utils.lrelu(h3_batchnorm, name='h3_lrelu')

        # Patch GAN: 13 -> 13
        output = tf_utils.conv2d(h3_lrelu, self.dis_c[4], k_h=3, k_w=3, d_h=1, d_w=1,
                                 name='output_conv2d')

    # set reuse=True for next call
    self.reuse = True
    self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

    return output
def __call__(self, x):
    with tf.variable_scope(self.name, reuse=self.reuse):
        tf_utils.print_activations(x)

        # conv: (N, H, W, 3) -> (N, H/2, W/2, 64)
        output = tf_utils.conv2d(x, self.ndf, k_h=4, k_w=4, d_h=2, d_w=2, padding='SAME',
                                 name='conv0_conv2d')
        output = tf_utils.lrelu(output, name='conv0_lrelu', is_print=True)

        for idx, hidden_dim in enumerate(self.hidden_dims[1:]):
            # conv: halve the spatial resolution and double the channels at each step
            output = tf_utils.conv2d(output, hidden_dim, k_h=4, k_w=4, d_h=2, d_w=2,
                                     padding='SAME', name='conv{}_conv2d'.format(idx + 1))
            output = tf_utils.norm(output, _type=self.norm, _ops=self._ops,
                                   name='conv{}_norm'.format(idx + 1))
            output = tf_utils.lrelu(output, name='conv{}_lrelu'.format(idx + 1), is_print=True)

        # conv: (N, H/16, W/16, 512) -> (N, H/16, W/16, 1)
        output = tf_utils.conv2d(output, 1, k_h=4, k_w=4, d_h=1, d_w=1, padding='SAME',
                                 name='conv4_conv2d')

    # set reuse=True for next call
    self.reuse = True
    self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

    return tf_utils.sigmoid(output), output
def network(self, inputs, name=None):
    with tf.variable_scope(name):
        tf_utils.print_activations(inputs)

        # input of main recurrent layers
        output = tf_utils.conv2d_mask(inputs, 2 * self.hidden_dims, [7, 7], mask_type="A",
                                      name='inputConv1')

        # main recurrent layers
        if self.flags.model == 'pixelcnn':
            for idx in range(self.recurrent_length):
                output = tf_utils.conv2d_mask(output, self.hidden_dims, [3, 3], mask_type="B",
                                              name='mainConv{}'.format(idx + 2))
                output = tf_utils.relu(output, name='mainRelu{}'.format(idx + 2))
        elif self.flags.model == 'diagonal_bilstm':
            for idx in range(self.recurrent_length):
                output = self.diagonal_bilstm(output, name='BiLSTM{}'.format(idx + 2))
        elif self.flags.model == 'row_lstm':
            raise NotImplementedError
        else:
            raise NotImplementedError

        # output recurrent layers
        for idx in range(self.out_recurrent_length):
            output = tf_utils.conv2d_mask(output, self.hidden_dims, [1, 1], mask_type="B",
                                          name='outputConv{}'.format(idx + 1))
            output = tf_utils.relu(output, name='outputRelu{}'.format(idx + 1))

        # TODO: for color images, implement a 256-way softmax for each RGB channel here
        output = tf_utils.conv2d_mask(output, self.img_size[2], [1, 1], mask_type="B",
                                      name='outputConv3')
        # output = tf_utils.sigmoid(output_logits, name='output_sigmoid')

        return tf_utils.sigmoid(output), output
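# Note: mask type "A" (used only for the input layer) hides the current pixel itself,
# while type "B" (used in deeper layers) allows it. The sketch below shows how such a
# raster-scan mask could be built; it is illustrative and does not reproduce the
# internals of `tf_utils.conv2d_mask`.
import numpy as np

def build_raster_mask(k_h, k_w, c_in, c_out, mask_type='A'):
    """Zero out kernel weights that would let a pixel see itself (type 'A') or any
    pixel below / to the right of it in raster-scan order."""
    mask = np.ones((k_h, k_w, c_in, c_out), dtype=np.float32)
    center_h, center_w = k_h // 2, k_w // 2
    # Block the center position for type 'A'; allow it for type 'B'.
    start_w = center_w if mask_type == 'A' else center_w + 1
    mask[center_h, start_w:, :, :] = 0.0
    # Block everything strictly below the center row.
    mask[center_h + 1:, :, :, :] = 0.0
    return mask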
def __init__(self, input_dim, output_dim=1, optimizer=None, use_dropout=True, lr=0.001, random_seed=123, is_train=True, log_dir=None, name=None): self.name = name self.is_train = is_train self.log_dir = log_dir self.cur_lr = None self.logger, self.file_handler, self.stream_handler = utils.init_logger(log_dir=self.log_dir, name=self.name, is_train=self.is_train) with tf.variable_scope(self.name): # Placeholders for inputs self.X = tf.placeholder(dtype=tf.float32, shape=[None, input_dim], name='X') self.y = tf.placeholder(dtype=tf.float32, shape=[None, output_dim], name='y') self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') tf_utils.print_activations(self.X, logger=self.logger if self.is_train else None) # Placeholders for TensorBoard self.train_acc = tf.placeholder(tf.float32, name='train_acc') self.val_acc = tf.placeholder(tf.float32, name='val_acc') net = self.X if use_dropout: net = tf_utils.dropout(x=net, keep_prob=self.keep_prob, seed=random_seed, name='dropout', logger=self.logger if self.is_train else None) # Network, loss, and optimizer self.y_pred = tf_utils.linear(net, output_size=output_dim) tf_utils.print_activations(self.y_pred, logger=self.logger if self.is_train else None) self.loss = tf.math.reduce_mean(tf.nn.l2_loss(self.y_pred - self.y)) self.train_op, self.cur_lr = optimizer_fn(optimizer, lr=lr, loss=self.loss, name=self.name) # Accuracy etc self.y_pred_round = tf.math.round(x=self.y_pred, name='rounded_pred') accuracy = tf.equal(tf.cast(x=self.y_pred_round, dtype=tf.int32), tf.cast(x=self.y, dtype=tf.int32)) self.accuracy = tf.reduce_mean(tf.cast(x=accuracy, dtype=tf.float32)) * 100. self._tensorboard() tf_utils.show_all_variables(logger=self.logger if self.is_train else None)
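# Note: a minimal training-step sketch for the model defined above, assuming a
# `tf.Session` named `sess`, a constructed instance `model`, and NumPy batches
# `x_batch`, `y_batch`, `x_val`, `y_val` (all of these names are illustrative):
_, loss_val = sess.run([model.train_op, model.loss],
                       feed_dict={model.X: x_batch,
                                  model.y: y_batch,
                                  model.keep_prob: 0.5})   # dropout active during training

val_acc = sess.run(model.accuracy,
                   feed_dict={model.X: x_val,
                              model.y: y_val,
                              model.keep_prob: 1.0})       # dropout disabled for evaluation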
def __call__(self, x): with tf.variable_scope(self.name, reuse=self.reuse): x = tf.concat([x, x, x], axis=-1, name='concat') tf_utils.print_activations(x) # conv1 relu1_1 = self.conv_layer(x, 'conv1_1', trainable=False) relu1_2 = self.conv_layer(relu1_1, 'conv1_2', trainable=False) pool_1 = tf_utils.max_pool_2x2(relu1_2, name='max_pool_1') tf_utils.print_activations(pool_1) # conv2 relu2_1 = self.conv_layer(pool_1, 'conv2_1', trainable=False) relu2_2 = self.conv_layer(relu2_1, 'conv2_2', trainable=False) pool_2 = tf_utils.max_pool_2x2(relu2_2, name='max_pool_2') tf_utils.print_activations(pool_2) # conv3 relu3_1 = self.conv_layer(pool_2, 'conv3_1', trainable=False) relu3_2 = self.conv_layer(relu3_1, 'conv3_2', trainable=False) relu3_3 = self.conv_layer(relu3_2, 'conv3_3', trainable=False) pool_3 = tf_utils.max_pool_2x2(relu3_3, name='max_pool_3') tf_utils.print_activations(pool_3) # conv4 relu4_1 = self.conv_layer(pool_3, 'conv4_1', trainable=False) relu4_2 = self.conv_layer(relu4_1, 'conv4_2', trainable=False) relu4_3 = self.conv_layer(relu4_2, 'conv4_3', trainable=False) pool_4 = tf_utils.max_pool_2x2(relu4_3, name='max_pool_4') tf_utils.print_activations(pool_4) # conv5 relu5_1 = self.conv_layer(pool_4, 'conv5_1', trainable=False) relu5_2 = self.conv_layer(relu5_1, 'conv5_2', trainable=False) relu5_3 = self.conv_layer(relu5_2, 'conv5_3', trainable=False) # set reuse=True for next call self.reuse = True return relu5_3
def conv_layer(self, bottom, name, trainable=False):
    with tf.variable_scope(name):
        w = self.get_conv_weight(name)
        b = self.get_bias(name)

        conv_weights = tf.get_variable("W", shape=w.shape,
                                       initializer=tf.constant_initializer(w),
                                       trainable=trainable)
        conv_biases = tf.get_variable("b", shape=b.shape,
                                      initializer=tf.constant_initializer(b),
                                      trainable=trainable)

        conv = tf.nn.conv2d(bottom, conv_weights, [1, 1, 1, 1], padding='SAME')
        bias = tf.nn.bias_add(conv, conv_biases)
        relu = tf.nn.relu(bias)
        tf_utils.print_activations(relu)

        return relu
def __init__(self, input_dim, output_dim=[1000, 1000, 10], optimizer=None, use_dropout=True, lr=0.001, weight_decay=1e-4, random_seed=123, is_train=True, log_dir=None, name=None): self.name = name self.is_train = is_train self.log_dir = log_dir self.cur_lr = None self.logger, self.file_handler, self.stream_handler = utils.init_logger(log_dir=self.log_dir, name=self.name, is_train=self.is_train) with tf.variable_scope(self.name): # Placeholders for inputs self.X = tf.placeholder(dtype=tf.float32, shape=[None, input_dim], name='X') tf_utils.print_activations(self.X, logger=self.logger if self.is_train else None) self.y = tf.placeholder(dtype=tf.float32, shape=[None, output_dim[-1]], name='y') self.y_cls = tf.math.argmax(input=self.y, axis=1) self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') # Placeholders for TensorBoard self.train_acc = tf.placeholder(tf.float32, name='train_acc') self.val_acc = tf.placeholder(tf.float32, name='val_acc') net = self.X for idx in range(len(output_dim) - 1): net = tf_utils.linear(x=net, output_size=output_dim[idx], name='fc'+str(idx), logger=self.logger if self.is_train else None) if use_dropout: net = tf_utils.dropout(x=net, keep_prob=self.keep_prob, seed=random_seed, name='dropout'+str(idx), logger=self.logger if self.is_train else None) net = tf_utils.relu(x=net, name='relu'+str(idx), logger=self.logger if self.is_train else None) # Last predict layer self.y_pred = tf_utils.linear(net, output_size=output_dim[-1], name='last_fc') tf_utils.print_activations(self.y_pred, logger=self.logger if self.is_train else None) # Loss = data loss + regularization term self.data_loss = tf.math.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=self.y_pred, labels=self.y)) self.reg_term = weight_decay * tf.reduce_sum( [tf.nn.l2_loss(weight) for weight in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]) self.loss = self.data_loss + self.reg_term # Optimizer self.train_op, self.cur_lr = optimizer_fn(optimizer, lr=lr, loss=self.loss, name=self.name) # Accuracy etc self.y_pred_cls = tf.math.argmax(input=self.y_pred, axis=1) correct_prediction = tf.math.equal(self.y_pred_cls, self.y_cls) self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32)) * 100. self._tensorboard() tf_utils.show_all_variables(logger=self.logger if self.is_train else None)
def __call__(self, x): with tf.variable_scope(self.name, reuse=self.reuse): tf_utils.print_activations(x) # (300, 200) -> (150, 100) e0_conv2d = tf_utils.conv2d(x, self.gen_c[0], name='e0_conv2d') e0_lrelu = tf_utils.lrelu(e0_conv2d, name='e0_lrelu') # (150, 100) -> (75, 50) e1_conv2d = tf_utils.conv2d(e0_lrelu, self.gen_c[1], name='e1_conv2d') e1_batchnorm = tf_utils.batch_norm(e1_conv2d, name='e1_batchnorm', _ops=self._ops) e1_lrelu = tf_utils.lrelu(e1_batchnorm, name='e1_lrelu') # (75, 50) -> (38, 25) e2_conv2d = tf_utils.conv2d(e1_lrelu, self.gen_c[2], name='e2_conv2d') e2_batchnorm = tf_utils.batch_norm(e2_conv2d, name='e2_batchnorm', _ops=self._ops) e2_lrelu = tf_utils.lrelu(e2_batchnorm, name='e2_lrelu') # (38, 25) -> (19, 13) e3_conv2d = tf_utils.conv2d(e2_lrelu, self.gen_c[3], name='e3_conv2d') e3_batchnorm = tf_utils.batch_norm(e3_conv2d, name='e3_batchnorm', _ops=self._ops) e3_lrelu = tf_utils.lrelu(e3_batchnorm, name='e3_lrelu') # (19, 13) -> (10, 7) e4_conv2d = tf_utils.conv2d(e3_lrelu, self.gen_c[4], name='e4_conv2d') e4_batchnorm = tf_utils.batch_norm(e4_conv2d, name='e4_batchnorm', _ops=self._ops) e4_lrelu = tf_utils.lrelu(e4_batchnorm, name='e4_lrelu') # (10, 7) -> (5, 4) e5_conv2d = tf_utils.conv2d(e4_lrelu, self.gen_c[5], name='e5_conv2d') e5_batchnorm = tf_utils.batch_norm(e5_conv2d, name='e5_batchnorm', _ops=self._ops) e5_lrelu = tf_utils.lrelu(e5_batchnorm, name='e5_lrelu') # (5, 4) -> (3, 2) e6_conv2d = tf_utils.conv2d(e5_lrelu, self.gen_c[6], name='e6_conv2d') e6_batchnorm = tf_utils.batch_norm(e6_conv2d, name='e6_batchnorm', _ops=self._ops) e6_lrelu = tf_utils.lrelu(e6_batchnorm, name='e6_lrelu') # (3, 2) -> (2, 1) e7_conv2d = tf_utils.conv2d(e6_lrelu, self.gen_c[7], name='e7_conv2d') e7_batchnorm = tf_utils.batch_norm(e7_conv2d, name='e7_batchnorm', _ops=self._ops) e7_relu = tf_utils.relu(e7_batchnorm, name='e7_relu') # (2, 1) -> (4, 2) d0_deconv = tf_utils.deconv2d(e7_relu, self.gen_c[8], name='d0_deconv2d') shapeA = e6_conv2d.get_shape().as_list()[1] shapeB = d0_deconv.get_shape().as_list()[1] - e6_conv2d.get_shape( ).as_list()[1] # (4, 2) -> (3, 2) d0_split, _ = tf.split(d0_deconv, [shapeA, shapeB], axis=1, name='d0_split') tf_utils.print_activations(d0_split) d0_batchnorm = tf_utils.batch_norm(d0_split, name='d0_batchnorm', _ops=self._ops) d0_drop = tf.nn.dropout(d0_batchnorm, keep_prob=0.5, name='d0_dropout') d0_concat = tf.concat([d0_drop, e6_batchnorm], axis=3, name='d0_concat') d0_relu = tf_utils.relu(d0_concat, name='d0_relu') # (3, 2) -> (6, 4) d1_deconv = tf_utils.deconv2d(d0_relu, self.gen_c[9], name='d1_deconv2d') # (6, 4) -> (5, 4) shapeA = e5_batchnorm.get_shape().as_list()[1] shapeB = d1_deconv.get_shape().as_list( )[1] - e5_batchnorm.get_shape().as_list()[1] d1_split, _ = tf.split(d1_deconv, [shapeA, shapeB], axis=1, name='d1_split') tf_utils.print_activations(d1_split) d1_batchnorm = tf_utils.batch_norm(d1_split, name='d1_batchnorm', _ops=self._ops) d1_drop = tf.nn.dropout(d1_batchnorm, keep_prob=0.5, name='d1_dropout') d1_concat = tf.concat([d1_drop, e5_batchnorm], axis=3, name='d1_concat') d1_relu = tf_utils.relu(d1_concat, name='d1_relu') # (5, 4) -> (10, 8) d2_deconv = tf_utils.deconv2d(d1_relu, self.gen_c[10], name='d2_deconv2d') # (10, 8) -> (10, 7) shapeA = e4_batchnorm.get_shape().as_list()[2] shapeB = d2_deconv.get_shape().as_list( )[2] - e4_batchnorm.get_shape().as_list()[2] d2_split, _ = tf.split(d2_deconv, [shapeA, shapeB], axis=2, name='d2_split') tf_utils.print_activations(d2_split) d2_batchnorm = tf_utils.batch_norm(d2_split, 
name='d2_batchnorm', _ops=self._ops) d2_drop = tf.nn.dropout(d2_batchnorm, keep_prob=0.5, name='d2_dropout') d2_concat = tf.concat([d2_drop, e4_batchnorm], axis=3, name='d2_concat') d2_relu = tf_utils.relu(d2_concat, name='d2_relu') # (10, 7) -> (20, 14) d3_deconv = tf_utils.deconv2d(d2_relu, self.gen_c[11], name='d3_deconv2d') # (20, 14) -> (19, 14) shapeA = e3_batchnorm.get_shape().as_list()[1] shapeB = d3_deconv.get_shape().as_list( )[1] - e3_batchnorm.get_shape().as_list()[1] d3_split_1, _ = tf.split(d3_deconv, [shapeA, shapeB], axis=1, name='d3_split_1') tf_utils.print_activations(d3_split_1) # (19, 14) -> (19, 13) shapeA = e3_batchnorm.get_shape().as_list()[2] shapeB = d3_split_1.get_shape().as_list( )[2] - e3_batchnorm.get_shape().as_list()[2] d3_split_2, _ = tf.split(d3_split_1, [shapeA, shapeB], axis=2, name='d3_split_2') tf_utils.print_activations(d3_split_2) d3_batchnorm = tf_utils.batch_norm(d3_split_2, name='d3_batchnorm', _ops=self._ops) d3_concat = tf.concat([d3_batchnorm, e3_batchnorm], axis=3, name='d3_concat') d3_relu = tf_utils.relu(d3_concat, name='d3_relu') # (19, 13) -> (38, 26) d4_deconv = tf_utils.deconv2d(d3_relu, self.gen_c[12], name='d4_deconv2d') # (38, 26) -> (38, 25) shapeA = e2_batchnorm.get_shape().as_list()[2] shapeB = d4_deconv.get_shape().as_list( )[2] - e2_batchnorm.get_shape().as_list()[2] d4_split, _ = tf.split(d4_deconv, [shapeA, shapeB], axis=2, name='d4_split') tf_utils.print_activations(d4_split) d4_batchnorm = tf_utils.batch_norm(d4_split, name='d4_batchnorm', _ops=self._ops) d4_concat = tf.concat([d4_batchnorm, e2_batchnorm], axis=3, name='d4_concat') d4_relu = tf_utils.relu(d4_concat, name='d4_relu') # (38, 25) -> (76, 50) d5_deconv = tf_utils.deconv2d(d4_relu, self.gen_c[13], name='d5_deconv2d') # (76, 50) -> (75, 50) shapeA = e1_batchnorm.get_shape().as_list()[1] shapeB = d5_deconv.get_shape().as_list( )[1] - e1_batchnorm.get_shape().as_list()[1] d5_split, _ = tf.split(d5_deconv, [shapeA, shapeB], axis=1, name='d5_split') tf_utils.print_activations(d5_split) d5_batchnorm = tf_utils.batch_norm(d5_split, name='d5_batchnorm', _ops=self._ops) d5_concat = tf.concat([d5_batchnorm, e1_batchnorm], axis=3, name='d5_concat') d5_relu = tf_utils.relu(d5_concat, name='d5_relu') # (75, 50) -> (150, 100) d6_deconv = tf_utils.deconv2d(d5_relu, self.gen_c[14], name='d6_deconv2d') d6_batchnorm = tf_utils.batch_norm(d6_deconv, name='d6_batchnorm', _ops=self._ops) d6_concat = tf.concat([d6_batchnorm, e0_conv2d], axis=3, name='d6_concat') d6_relu = tf_utils.relu(d6_concat, name='d6_relu') # (150, 100) -> (300, 200) d7_deconv = tf_utils.deconv2d(d6_relu, self.gen_c[15], name='d7_deconv2d') output = tf_utils.tanh(d7_deconv, name='output_tanh') # set reuse=True for next call self.reuse = True self.variables = tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) return output
def __call__(self, x):
    with tf.variable_scope(self.name, reuse=self.reuse):
        tf_utils.print_activations(x)

        # conv: (N, H, W, C) -> (N, H/2, W/2, 64)
        output = tf_utils.conv2d(x, self.conv_dims[0], k_h=4, k_w=4, d_h=2, d_w=2,
                                 padding='SAME', name='conv0_conv2d')
        output = tf_utils.lrelu(output, name='conv0_lrelu', is_print=True)

        for idx, conv_dim in enumerate(self.conv_dims[1:]):
            # conv: (N, H/2, W/2, C) -> (N, H/4, W/4, 2C)
            output = tf_utils.conv2d(output, conv_dim, k_h=4, k_w=4, d_h=2, d_w=2,
                                     padding='SAME', name='conv{}_conv2d'.format(idx + 1))
            output = tf_utils.norm(output, _type=self.norm, _ops=self._ops,
                                   name='conv{}_norm'.format(idx + 1))
            output = tf_utils.lrelu(output, name='conv{}_lrelu'.format(idx + 1), is_print=True)

        for idx, deconv_dim in enumerate(self.deconv_dims):
            # deconv: (N, H/16, W/16, C) -> (N, H/8, W/8, C/2)
            output = tf_utils.deconv2d(output, deconv_dim, k_h=4, k_w=4,
                                       name='deconv{}_conv2d'.format(idx))
            output = tf_utils.norm(output, _type=self.norm, _ops=self._ops,
                                   name='deconv{}_norm'.format(idx))
            output = tf_utils.relu(output, name='deconv{}_relu'.format(idx), is_print=True)

        # deconv: (N, H/2, W/2, 64) -> (N, H, W, 3)
        output = tf_utils.deconv2d(output, self.output_channel, k_h=4, k_w=4,
                                   name='conv3_deconv2d')
        output = tf_utils.tanh(output, name='conv4_tanh', is_print=True)

    # set reuse=True for next call
    self.reuse = True
    self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

    return output
def __call__(self, x, is_train=True): with tf.variable_scope(self.name, reuse=self.reuse): tf_utils.print_activations(x) # (N, 120, 160, 1) -> (N, 60, 80, 64) h0_conv = tf_utils.conv2d( x, output_dim=self.dims[0], initializer='he', name='h0_conv', logger=self.logger if is_train is True else None) h0_lrelu = tf_utils.lrelu( h0_conv, name='h0_lrelu', logger=self.logger if is_train is True else None) # (N, 60, 80, 64) -> (N, 30, 40, 128) h1_conv = tf_utils.conv2d( h0_lrelu, output_dim=self.dims[1], initializer='he', name='h1_conv', logger=self.logger if is_train is True else None) h1_norm = tf_utils.norm( h1_conv, name='h1_batch', _type='batch', _ops=self._ops, is_train=is_train, logger=self.logger if is_train is True else None) h1_lrelu = tf_utils.lrelu( h1_norm, name='h1_lrelu', logger=self.logger if is_train is True else None) # (N, 30, 40, 128) -> (N, 15, 20, 256) h2_conv = tf_utils.conv2d( h1_lrelu, output_dim=self.dims[2], initializer='he', name='h2_conv', logger=self.logger if is_train is True else None) h2_norm = tf_utils.norm( h2_conv, name='h2_batch', _type='batch', _ops=self._ops, is_train=is_train, logger=self.logger if is_train is True else None) h2_lrelu = tf_utils.lrelu( h2_norm, name='h2_lrelu', logger=self.logger if is_train is True else None) # (N, 15, 20, 256) -> (N, 8, 10, 512) h3_conv = tf_utils.conv2d( h2_lrelu, output_dim=self.dims[3], initializer='he', name='h3_conv', logger=self.logger if is_train is True else None) h3_norm = tf_utils.norm( h3_conv, name='h3_batch', _type='batch', _ops=self._ops, is_train=is_train, logger=self.logger if is_train is True else None) h3_lrelu = tf_utils.lrelu( h3_norm, name='h3_lrelu', logger=self.logger if is_train is True else None) # (N, 8, 10, 512) -> (N, 4, 5, 1024) h4_conv = tf_utils.conv2d( h3_lrelu, output_dim=self.dims[4], initializer='he', name='h4_conv', logger=self.logger if is_train is True else None) h4_norm = tf_utils.norm( h4_conv, name='h4_batch', _type='batch', _ops=self._ops, is_train=is_train, logger=self.logger if is_train is True else None) h4_lrelu = tf_utils.lrelu( h4_norm, name='h4_lrelu', logger=self.logger if is_train is True else None) # (N, 4, 5, 1024) -> (N, 4*5*1024) h4_flatten = tf_utils.flatten( h4_lrelu, name='h4_flatten', logger=self.logger if is_train is True else None) # (N, 4*5*1024) -> (N, 1) output = tf_utils.linear( h4_flatten, output_size=self.dims[5], initializer='he', name='output', logger=self.logger if is_train is True else None) # Set reuse=True for next call self.reuse = True self.variables = tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) return output
def generator(self, data, name='g_'): with tf.variable_scope(name): data_flatten = flatten(data) tf_utils.print_activations(data_flatten) # from (N, 128) to (N, 2, 4, 512) h0_linear = tf_utils.linear(data_flatten, self.gen_c[0], name='h0_linear') h0_reshape = tf.reshape( h0_linear, [tf.shape(h0_linear)[0], 2, 4, int(self.gen_c[0] / (2 * 4))]) # (N, 4, 8, 512) resblock_1 = tf_utils.res_block_v2(h0_reshape, self.gen_c[1], filter_size=3, _ops=self.gen_train_ops, norm_='batch', resample='up', name='res_block_1') # (N, 8, 16, 256) resblock_2 = tf_utils.res_block_v2(resblock_1, self.gen_c[2], filter_size=3, _ops=self.gen_train_ops, norm_='batch', resample='up', name='res_block_2') # (N, 16, 32, 128) resblock_3 = tf_utils.res_block_v2(resblock_2, self.gen_c[3], filter_size=3, _ops=self.gen_train_ops, norm_='batch', resample='up', name='res_block_3') # (N, 32, 64, 64) resblock_4 = tf_utils.res_block_v2(resblock_3, self.gen_c[4], filter_size=3, _ops=self.gen_train_ops, norm_='batch', resample='up', name='res_block_4') # (N, 64, 128, 32) resblock_5 = tf_utils.res_block_v2(resblock_4, self.gen_c[5], filter_size=3, _ops=self.gen_train_ops, norm_='batch', resample='up', name='res_block_5') # (N, 128, 256, 32) resblock_6 = tf_utils.res_block_v2(resblock_5, self.gen_c[6], filter_size=3, _ops=self.gen_train_ops, norm_='batch', resample='up', name='res_block_6') norm_7 = tf_utils.norm(resblock_6, _type='batch', _ops=self.gen_train_ops, name='norm_7') relu_7 = tf_utils.relu(norm_7, name='relu_7') # (N, 128, 256, 3) output = tf_utils.conv2d(relu_7, output_dim=self.image_size[2], k_w=3, k_h=3, d_h=1, d_w=1, name='output') return tf_utils.tanh(output)
def forward_network(self, inputImg, padding='SAME', reuse=False): with tf.compat.v1.variable_scope(self.name, reuse=reuse): # This part is for compatible between input size [640, 400] and [320, 200] if self.resize_factor == 1.0: # Stage 0 tf_utils.print_activations(inputImg, logger=self.logger) s0_conv1 = tf_utils.conv2d(x=inputImg, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s0_conv1', logger=self.logger) s0_conv1 = tf_utils.relu(s0_conv1, name='relu_s0_conv1', logger=self.logger) s0_conv2 = tf_utils.conv2d(x=s0_conv1, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s0_conv2', logger=self.logger) if self.use_batch_norm: s0_conv2 = tf_utils.norm(s0_conv2, name='s0_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s0_conv2 = tf_utils.relu(s0_conv2, name='relu_s0_conv2', logger=self.logger) # Stage 1 s1_maxpool = tf_utils.max_pool(x=s0_conv2, name='s1_maxpool2d', logger=self.logger) s1_conv1 = tf_utils.conv2d(x=s1_maxpool, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s1_conv1', logger=self.logger) if self.use_batch_norm: s1_conv1 = tf_utils.norm(s1_conv1, name='s1_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1', logger=self.logger) s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s1_conv2', logger=self.logger) if self.use_batch_norm: s1_conv2 = tf_utils.norm(s1_conv2, name='s1_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2', logger=self.logger) else: # Stage 1 tf_utils.print_activations(inputImg, logger=self.logger) s1_conv1 = tf_utils.conv2d(x=inputImg, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s1_conv1', logger=self.logger) s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1', logger=self.logger) s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s1_conv2', logger=self.logger) if self.use_batch_norm: s1_conv2 = tf_utils.norm(s1_conv2, name='s1_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2', logger=self.logger) # Stage 2 s2_maxpool = tf_utils.max_pool(x=s1_conv2, name='s2_maxpool2d', logger=self.logger) s2_conv1 = tf_utils.conv2d(x=s2_maxpool, output_dim=self.conv_dims[2], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s2_conv1', logger=self.logger) if self.use_batch_norm: s2_conv1 = tf_utils.norm(s2_conv1, name='s2_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s2_conv1 = tf_utils.relu(s2_conv1, name='relu_s2_conv1', logger=self.logger) s2_conv2 = tf_utils.conv2d(x=s2_conv1, output_dim=self.conv_dims[3], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s2_conv2', logger=self.logger) if self.use_batch_norm: s2_conv2 = tf_utils.norm(s2_conv2, name='s2_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s2_conv2 = tf_utils.relu(s2_conv2, name='relu_s2_conv2', logger=self.logger) # Stage 3 s3_maxpool = tf_utils.max_pool(x=s2_conv2, name='s3_maxpool2d', logger=self.logger) s3_conv1 = 
tf_utils.conv2d(x=s3_maxpool, output_dim=self.conv_dims[4], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s3_conv1', logger=self.logger) if self.use_batch_norm: s3_conv1 = tf_utils.norm(s3_conv1, name='s3_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s3_conv1 = tf_utils.relu(s3_conv1, name='relu_s3_conv1', logger=self.logger) s3_conv2 = tf_utils.conv2d(x=s3_conv1, output_dim=self.conv_dims[5], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s3_conv2', logger=self.logger) if self.use_batch_norm: s3_conv2 = tf_utils.norm(s3_conv2, name='s3_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s3_conv2 = tf_utils.relu(s3_conv2, name='relu_s3_conv2', logger=self.logger) # Stage 4 s4_maxpool = tf_utils.max_pool(x=s3_conv2, name='s4_maxpool2d', logger=self.logger) s4_conv1 = tf_utils.conv2d(x=s4_maxpool, output_dim=self.conv_dims[6], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s4_conv1', logger=self.logger) if self.use_batch_norm: s4_conv1 = tf_utils.norm(s4_conv1, name='s4_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s4_conv1 = tf_utils.relu(s4_conv1, name='relu_s4_conv1', logger=self.logger) s4_conv2 = tf_utils.conv2d(x=s4_conv1, output_dim=self.conv_dims[7], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s4_conv2', logger=self.logger) if self.use_batch_norm: s4_conv2 = tf_utils.norm(s4_conv2, name='s4_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s4_conv2 = tf_utils.relu(s4_conv2, name='relu_s4_conv2', logger=self.logger) s4_conv2_drop = tf_utils.dropout(x=s4_conv2, keep_prob=self.ratePh, name='s4_dropout', logger=self.logger) # Stage 5 s5_maxpool = tf_utils.max_pool(x=s4_conv2_drop, name='s5_maxpool2d', logger=self.logger) s5_conv1 = tf_utils.conv2d(x=s5_maxpool, output_dim=self.conv_dims[8], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s5_conv1', logger=self.logger) if self.use_batch_norm: s5_conv1 = tf_utils.norm(s5_conv1, name='s5_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s5_conv1 = tf_utils.relu(s5_conv1, name='relu_s5_conv1', logger=self.logger) s5_conv2 = tf_utils.conv2d(x=s5_conv1, output_dim=self.conv_dims[9], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s5_conv2', logger=self.logger) if self.use_batch_norm: s5_conv2 = tf_utils.norm(s5_conv2, name='s5_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s5_conv2 = tf_utils.relu(s5_conv2, name='relu_s5_conv2', logger=self.logger) s5_conv2_drop = tf_utils.dropout(x=s5_conv2, keep_prob=self.ratePh, name='s5_dropout', logger=self.logger) # Stage 6 s6_deconv1 = tf_utils.deconv2d(x=s5_conv2_drop, output_dim=self.conv_dims[10], k_h=2, k_w=2, initializer='He', name='s6_deconv1', logger=self.logger) if self.use_batch_norm: s6_deconv1 = tf_utils.norm(s6_deconv1, name='s6_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s6_deconv1 = tf_utils.relu(s6_deconv1, name='relu_s6_deconv1', logger=self.logger) # Cropping w1 = s4_conv2_drop.get_shape().as_list()[2] w2 = s6_deconv1.get_shape().as_list()[2] - s4_conv2_drop.get_shape().as_list()[2] s6_deconv1_split, _ = tf.split(s6_deconv1, num_or_size_splits=[w1, w2], axis=2, name='axis2_split') tf_utils.print_activations(s6_deconv1_split, logger=self.logger) # Concat s6_concat = tf_utils.concat(values=[s6_deconv1_split, 
s4_conv2_drop], axis=3, name='s6_axis3_concat', logger=self.logger) s6_conv2 = tf_utils.conv2d(x=s6_concat, output_dim=self.conv_dims[11], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s6_conv2', logger=self.logger) if self.use_batch_norm: s6_conv2 = tf_utils.norm(s6_conv2, name='s6_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s6_conv2 = tf_utils.relu(s6_conv2, name='relu_s6_conv2', logger=self.logger) s6_conv3 = tf_utils.conv2d(x=s6_conv2, output_dim=self.conv_dims[12], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s6_conv3', logger=self.logger) if self.use_batch_norm: s6_conv3 = tf_utils.norm(s6_conv3, name='s6_norm2', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s6_conv3 = tf_utils.relu(s6_conv3, name='relu_s6_conv3', logger=self.logger) # Stage 7 s7_deconv1 = tf_utils.deconv2d(x=s6_conv3, output_dim=self.conv_dims[13], k_h=2, k_w=2, initializer='He', name='s7_deconv1', logger=self.logger) if self.use_batch_norm: s7_deconv1 = tf_utils.norm(s7_deconv1, name='s7_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s7_deconv1 = tf_utils.relu(s7_deconv1, name='relu_s7_deconv1', logger=self.logger) # Concat s7_concat = tf_utils.concat(values=[s7_deconv1, s3_conv2], axis=3, name='s7_axis3_concat', logger=self.logger) s7_conv2 = tf_utils.conv2d(x=s7_concat, output_dim=self.conv_dims[14], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s7_conv2', logger=self.logger) if self.use_batch_norm: s7_conv2 = tf_utils.norm(s7_conv2, name='s7_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s7_conv2 = tf_utils.relu(s7_conv2, name='relu_s7_conv2', logger=self.logger) s7_conv3 = tf_utils.conv2d(x=s7_conv2, output_dim=self.conv_dims[15], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s7_conv3', logger=self.logger) if self.use_batch_norm: s7_conv3 = tf_utils.norm(s7_conv3, name='s7_norm2', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s7_conv3 = tf_utils.relu(s7_conv3, name='relu_s7_conv3', logger=self.logger) # Stage 8 s8_deconv1 = tf_utils.deconv2d(x=s7_conv3, output_dim=self.conv_dims[16], k_h=2, k_w=2, initializer='He', name='s8_deconv1', logger=self.logger) if self.use_batch_norm: s8_deconv1 = tf_utils.norm(s8_deconv1, name='s8_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s8_deconv1 = tf_utils.relu(s8_deconv1, name='relu_s8_deconv1', logger=self.logger) # Concat s8_concat = tf_utils.concat(values=[s8_deconv1,s2_conv2], axis=3, name='s8_axis3_concat', logger=self.logger) s8_conv2 = tf_utils.conv2d(x=s8_concat, output_dim=self.conv_dims[17], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s8_conv2', logger=self.logger) if self.use_batch_norm: s8_conv2 = tf_utils.norm(s8_conv2, name='s8_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s8_conv2 = tf_utils.relu(s8_conv2, name='relu_s8_conv2', logger=self.logger) s8_conv3 = tf_utils.conv2d(x=s8_conv2, output_dim=self.conv_dims[18], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s8_conv3', logger=self.logger) if self.use_batch_norm: s8_conv3 = tf_utils.norm(s8_conv3, name='s8_norm2', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s8_conv3 = tf_utils.relu(s8_conv3, name='relu_conv3', logger=self.logger) # Stage 9 s9_deconv1 = tf_utils.deconv2d(x=s8_conv3, 
output_dim=self.conv_dims[19], k_h=2, k_w=2, initializer='He', name='s9_deconv1', logger=self.logger) if self.use_batch_norm: s9_deconv1 = tf_utils.norm(s9_deconv1, name='s9_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s9_deconv1 = tf_utils.relu(s9_deconv1, name='relu_s9_deconv1', logger=self.logger) # Concat s9_concat = tf_utils.concat(values=[s9_deconv1, s1_conv2], axis=3, name='s9_axis3_concat', logger=self.logger) s9_conv2 = tf_utils.conv2d(x=s9_concat, output_dim=self.conv_dims[20], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s9_conv2', logger=self.logger) if self.use_batch_norm: s9_conv2 = tf_utils.norm(s9_conv2, name='s9_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s9_conv2 = tf_utils.relu(s9_conv2, name='relu_s9_conv2', logger=self.logger) s9_conv3 = tf_utils.conv2d(x=s9_conv2, output_dim=self.conv_dims[21], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s9_conv3', logger=self.logger) if self.use_batch_norm: s9_conv3 = tf_utils.norm(s9_conv3, name='s9_norm2', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s9_conv3 = tf_utils.relu(s9_conv3, name='relu_s9_conv3', logger=self.logger) if self.resize_factor == 1.0: s10_deconv1 = tf_utils.deconv2d(x=s9_conv3, output_dim=self.conv_dims[-1], k_h=2, k_w=2, initializer='He', name='s10_deconv1', logger=self.logger) if self.use_batch_norm: s10_deconv1 = tf_utils.norm(s10_deconv1, name='s10_norm0', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s10_deconv1 = tf_utils.relu(s10_deconv1, name='relu_s10_deconv1', logger=self.logger) # Concat s10_concat = tf_utils.concat(values=[s10_deconv1, s0_conv2], axis=3, name='s10_axis3_concat', logger=self.logger) s10_conv2 = tf_utils.conv2d(s10_concat, output_dim=self.conv_dims[-1], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s10_conv2', logger=self.logger) if self.use_batch_norm: s10_conv2 = tf_utils.norm(s10_conv2, name='s10_norm1', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s10_conv2 = tf_utils.relu(s10_conv2, name='relu_s10_conv2', logger=self.logger) s10_conv3 = tf_utils.conv2d(x=s10_conv2, output_dim=self.conv_dims[-1], k_h=3, k_w=3, d_h=1, d_w=1, padding=padding, initializer='He', name='s10_conv3', logger=self.logger) if self.use_batch_norm: s10_conv3 = tf_utils.norm(s10_conv3, name='s10_norm2', _type='batch', _ops=self._ops, is_train=self.trainMode, logger=self.logger) s10_conv3 = tf_utils.relu(s10_conv3, name='relu_s10_conv3', logger=self.logger) output = tf_utils.conv2d(s10_conv3, output_dim=self.numClasses, k_h=1, k_w=1, d_h=1, d_w=1, padding=padding, initializer='He', name='output', logger=self.logger) else: output = tf_utils.conv2d(s9_conv3, output_dim=self.numClasses, k_h=1, k_w=1, d_h=1, d_w=1, padding=padding, initializer='He', name='output', logger=self.logger) return output
def __call__(self, x, keep_rate=0.5): with tf.compat.v1.variable_scope(self.name, reuse=self.reuse): tf_utils.print_activations(x, logger=self.logger) # E0: (320, 200) -> (160, 100) e0_conv2d = tf_utils.conv2d(x, output_dim=self.gen_c[0], initializer='He', logger=self.logger, name='e0_conv2d') e0_lrelu = tf_utils.lrelu(e0_conv2d, logger=self.logger, name='e0_lrelu') # E1: (160, 100) -> (80, 50) e1_conv2d = tf_utils.conv2d(e0_lrelu, output_dim=self.gen_c[1], initializer='He', logger=self.logger, name='e1_conv2d') e1_batchnorm = tf_utils.norm(e1_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='e1_norm') e1_lrelu = tf_utils.lrelu(e1_batchnorm, logger=self.logger, name='e1_lrelu') # E2: (80, 50) -> (40, 25) e2_conv2d = tf_utils.conv2d(e1_lrelu, output_dim=self.gen_c[2], initializer='He', logger=self.logger, name='e2_conv2d') e2_batchnorm = tf_utils.norm(e2_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='e2_norm') e2_lrelu = tf_utils.lrelu(e2_batchnorm, logger=self.logger, name='e2_lrelu') # E3: (40, 25) -> (20, 13) e3_conv2d = tf_utils.conv2d(e2_lrelu, output_dim=self.gen_c[3], initializer='He', logger=self.logger, name='e3_conv2d') e3_batchnorm = tf_utils.norm(e3_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='e3_norm') e3_lrelu = tf_utils.lrelu(e3_batchnorm, logger=self.logger, name='e3_lrelu') # E4: (20, 13) -> (10, 7) e4_conv2d = tf_utils.conv2d(e3_lrelu, output_dim=self.gen_c[4], initializer='He', logger=self.logger, name='e4_conv2d') e4_batchnorm = tf_utils.norm(e4_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='e4_norm') e4_lrelu = tf_utils.lrelu(e4_batchnorm, logger=self.logger, name='e4_lrelu') # E5: (10, 7) -> (5, 4) e5_conv2d = tf_utils.conv2d(e4_lrelu, output_dim=self.gen_c[5], initializer='He', logger=self.logger, name='e5_conv2d') e5_batchnorm = tf_utils.norm(e5_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='e5_norm') e5_lrelu = tf_utils.lrelu(e5_batchnorm, logger=self.logger, name='e5_lrelu') # E6: (5, 4) -> (3, 2) e6_conv2d = tf_utils.conv2d(e5_lrelu, output_dim=self.gen_c[6], initializer='He', logger=self.logger, name='e6_conv2d') e6_batchnorm = tf_utils.norm(e6_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='e6_norm') e6_lrelu = tf_utils.lrelu(e6_batchnorm, logger=self.logger, name='e6_lrelu') # E7: (3, 2) -> (2, 1) e7_conv2d = tf_utils.conv2d(e6_lrelu, output_dim=self.gen_c[7], initializer='He', logger=self.logger, name='e7_conv2d') e7_batchnorm = tf_utils.norm(e7_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='e7_norm') e7_relu = tf_utils.lrelu(e7_batchnorm, logger=self.logger, name='e7_relu') # D0: (2, 1) -> (3, 2) # Stage1: (2, 1) -> (4, 2) d0_deconv = tf_utils.deconv2d(e7_relu, output_dim=self.gen_c[8], initializer='He', logger=self.logger, name='d0_deconv2d') # Stage2: (4, 2) -> (3, 2) shapeA = e6_conv2d.get_shape().as_list()[1] shapeB = d0_deconv.get_shape().as_list()[1] - e6_conv2d.get_shape( ).as_list()[1] d0_split, _ = tf.split(d0_deconv, [shapeA, shapeB], axis=1, name='d0_split') tf_utils.print_activations(d0_split, logger=self.logger) # Stage3: Batch norm, concatenation, and relu d0_batchnorm = tf_utils.norm(d0_split, _type=self.norm, _ops=self._ops, logger=self.logger, name='d0_norm') d0_drop = tf_utils.dropout(d0_batchnorm, keep_prob=keep_rate, logger=self.logger, name='d0_dropout') d0_concat = tf.concat([d0_drop, e6_batchnorm], axis=3, name='d0_concat') d0_relu = tf_utils.relu(d0_concat, logger=self.logger, name='d0_relu') # D1: 
(3, 2) -> (5, 4) # Stage1: (3, 2) -> (6, 4) d1_deconv = tf_utils.deconv2d(d0_relu, output_dim=self.gen_c[9], initializer='He', logger=self.logger, name='d1_deconv2d') # Stage2: (6, 4) -> (5, 4) shapeA = e5_batchnorm.get_shape().as_list()[1] shapeB = d1_deconv.get_shape().as_list( )[1] - e5_batchnorm.get_shape().as_list()[1] d1_split, _ = tf.split(d1_deconv, [shapeA, shapeB], axis=1, name='d1_split') tf_utils.print_activations(d1_split, logger=self.logger) # Stage3: Batch norm, concatenation, and relu d1_batchnorm = tf_utils.norm(d1_split, _type=self.norm, _ops=self._ops, logger=self.logger, name='d1_norm') d1_drop = tf_utils.dropout(d1_batchnorm, keep_prob=keep_rate, logger=self.logger, name='d1_dropout') d1_concat = tf.concat([d1_drop, e5_batchnorm], axis=3, name='d1_concat') d1_relu = tf_utils.relu(d1_concat, logger=self.logger, name='d1_relu') # D2: (5, 4) -> (10, 7) # Stage1: (5, 4) -> (10, 8) d2_deconv = tf_utils.deconv2d(d1_relu, output_dim=self.gen_c[10], initializer='He', logger=self.logger, name='d2_deconv2d') # Stage2: (10, 8) -> (10, 7) shapeA = e4_batchnorm.get_shape().as_list()[2] shapeB = d2_deconv.get_shape().as_list( )[2] - e4_batchnorm.get_shape().as_list()[2] d2_split, _ = tf.split(d2_deconv, [shapeA, shapeB], axis=2, name='d2_split') tf_utils.print_activations(d2_split, logger=self.logger) # Stage3: Batch norm, concatenation, and relu d2_batchnorm = tf_utils.norm(d2_split, _type=self.norm, _ops=self._ops, logger=self.logger, name='d2_norm') d2_drop = tf_utils.dropout(d2_batchnorm, keep_prob=keep_rate, logger=self.logger, name='d2_dropout') d2_concat = tf.concat([d2_drop, e4_batchnorm], axis=3, name='d2_concat') d2_relu = tf_utils.relu(d2_concat, logger=self.logger, name='d2_relu') # D3: (10, 7) -> (20, 13) # Stage1: (10, 7) -> (20, 14) d3_deconv = tf_utils.deconv2d(d2_relu, output_dim=self.gen_c[11], initializer='He', logger=self.logger, name='d3_deconv2d') # Stage2: (20, 14) -> (20, 13) shapeA = e3_batchnorm.get_shape().as_list()[2] shapeB = d3_deconv.get_shape().as_list( )[2] - e3_batchnorm.get_shape().as_list()[2] d3_split, _ = tf.split(d3_deconv, [shapeA, shapeB], axis=2, name='d3_split_2') tf_utils.print_activations(d3_split, logger=self.logger) # Stage3: Batch norm, concatenation, and relu d3_batchnorm = tf_utils.norm(d3_split, _type=self.norm, _ops=self._ops, logger=self.logger, name='d3_norm') d3_concat = tf.concat([d3_batchnorm, e3_batchnorm], axis=3, name='d3_concat') d3_relu = tf_utils.relu(d3_concat, logger=self.logger, name='d3_relu') # D4: (20, 13) -> (40, 25) # Stage1: (20, 13) -> (40, 26) d4_deconv = tf_utils.deconv2d(d3_relu, output_dim=self.gen_c[12], initializer='He', logger=self.logger, name='d4_deconv2d') # Stage2: (40, 26) -> (40, 25) shapeA = e2_batchnorm.get_shape().as_list()[2] shapeB = d4_deconv.get_shape().as_list( )[2] - e2_batchnorm.get_shape().as_list()[2] d4_split, _ = tf.split(d4_deconv, [shapeA, shapeB], axis=2, name='d4_split') tf_utils.print_activations(d4_split, logger=self.logger) # Stage3: Batch norm, concatenation, and relu d4_batchnorm = tf_utils.norm(d4_split, _type=self.norm, _ops=self._ops, logger=self.logger, name='d4_norm') d4_concat = tf.concat([d4_batchnorm, e2_batchnorm], axis=3, name='d4_concat') d4_relu = tf_utils.relu(d4_concat, logger=self.logger, name='d4_relu') # D5: (40, 25, 256) -> (80, 50, 128) d5_deconv = tf_utils.deconv2d(d4_relu, output_dim=self.gen_c[13], initializer='He', logger=self.logger, name='d5_deconv2d') d5_batchnorm = tf_utils.norm(d5_deconv, _type=self.norm, _ops=self._ops, logger=self.logger, 
name='d5_norm') d5_concat = tf.concat([d5_batchnorm, e1_batchnorm], axis=3, name='d5_concat') d5_relu = tf_utils.relu(d5_concat, logger=self.logger, name='d5_relu') # D6: (80, 50, 128) -> (160, 100, 64) d6_deconv = tf_utils.deconv2d(d5_relu, output_dim=self.gen_c[14], initializer='He', logger=self.logger, name='d6_deconv2d') d6_batchnorm = tf_utils.norm(d6_deconv, _type=self.norm, _ops=self._ops, logger=self.logger, name='d6_norm') d6_concat = tf.concat([d6_batchnorm, e0_conv2d], axis=3, name='d6_concat') d6_relu = tf_utils.relu(d6_concat, logger=self.logger, name='d6_relu') # D7: (160, 100, 64) -> (320, 200, 1) d7_deconv = tf_utils.deconv2d(d6_relu, output_dim=self.gen_c[15], initializer='He', logger=self.logger, name='d7_deconv2d') output = tf_utils.tanh(d7_deconv, logger=self.logger, name='output_tanh') # Set reuse=True for next call self.reuse = True self.variables = tf.compat.v1.get_collection( tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) return output
def encoder(self, data, name='encoder'):
    with tf.variable_scope(name):
        data_flatten = flatten(data)
        tf_utils.print_activations(data_flatten)

        # 1st hidden layer
        h0_linear = tf_utils.linear(data_flatten, self.n_hidden, name='h0_linear')
        h0_elu = tf_utils.elu(h0_linear, name='h0_elu')
        h0_drop = tf.nn.dropout(h0_elu, keep_prob=self.keep_prob_tfph, name='h0_drop')
        tf_utils.print_activations(h0_drop)

        # 2nd hidden layer
        h1_linear = tf_utils.linear(h0_drop, self.n_hidden, name='h1_linear')
        h1_tanh = tf_utils.tanh(h1_linear, name='h1_tanh')
        h1_drop = tf.nn.dropout(h1_tanh, keep_prob=self.keep_prob_tfph, name='h1_drop')
        tf_utils.print_activations(h1_drop)

        # 3rd hidden layer
        h2_linear = tf_utils.linear(h1_drop, 2 * self.flags.z_dim, name='h2_linear')
        tf_utils.print_activations(h2_linear)

        # The mean parameter is unconstrained
        mean = h2_linear[:, :self.flags.z_dim]
        # The standard deviation must be positive.
        # Parameterize with a softplus and add a small epsilon for numerical stability
        stddev = 1e-6 + tf.nn.softplus(h2_linear[:, self.flags.z_dim:])
        tf_utils.print_activations(mean)
        tf_utils.print_activations(stddev)

        return mean, stddev
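# Note: a minimal sketch of how the (mean, stddev) pair from `encoder` is typically
# combined with `decoder` via the reparameterization trick, so gradients flow through
# the sampling step; `data` is a hypothetical input tensor assumed to be defined elsewhere.
mean, stddev = self.encoder(data)                    # (N, z_dim) each
eps = tf.random_normal(tf.shape(mean), 0.0, 1.0)     # eps ~ N(0, I)
z = mean + stddev * eps                              # z ~ N(mean, stddev^2)
x_recon = self.decoder(z)                            # (N, *image_size)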
def __call__(self, x, is_train=True): with tf.variable_scope(self.name, reuse=self.reuse): tf_utils.print_activations(x) # (N, 100) -> (N, 4, 5, 512) h0_linear = tf_utils.linear( x, 4 * 5 * self.dims[0], name='h0_linear', initializer='He', logger=self.logger if is_train is True else None) h0_reshape = tf.reshape( h0_linear, [tf.shape(h0_linear)[0], 4, 5, self.dims[0]]) # (N, 4, 5, 512) -> (N, 8, 10, 512) resblock_1 = tf_utils.res_block_v2( x=h0_reshape, k=self.dims[1], filter_size=3, _ops=self._ops, norm_='batch', resample='up', name='res_block_1', logger=self.logger if is_train is True else None) # (N, 8, 10, 512) -> (N, 16, 20, 256) resblock_2 = tf_utils.res_block_v2( x=resblock_1, k=self.dims[2], filter_size=3, _ops=self._ops, norm_='batch', resample='up', name='res_block_2', logger=self.logger if is_train is True else None) # (N, 16, 20, 256) -> (N, 15, 20, 256) resblock_2_split, _ = tf.split(resblock_2, [15, 1], axis=1, name='resblock_2_split') tf_utils.print_activations( resblock_2_split, logger=self.logger if is_train is True else None) # (N, 15, 20, 256) -> (N, 30, 40, 128) resblock_3 = tf_utils.res_block_v2( x=resblock_2_split, k=self.dims[3], filter_size=3, _ops=self._ops, norm_='batch', resample='up', name='res_block_3', logger=self.logger if is_train is True else None) # (N, 30, 40, 128) -> (N, 60, 80, 64) resblock_4 = tf_utils.res_block_v2( x=resblock_3, k=self.dims[4], filter_size=3, _ops=self._ops, norm_='batch', resample='up', name='res_block_4', logger=self.logger if is_train is True else None) # (N, 60, 80, 64) -> (N, 120, 160, 64) resblock_5 = tf_utils.res_block_v2( x=resblock_4, k=self.dims[5], filter_size=3, _ops=self._ops, norm_='batch', resample='up', name='res_block_5', logger=self.logger if is_train is True else None) norm_5 = tf_utils.norm( resblock_5, name='norm_5', _type='batch', _ops=self._ops, is_train=is_train, logger=self.logger if is_train is True else None) relu_5 = tf_utils.relu( norm_5, name='relu_5', logger=self.logger if is_train is True else None) # (N, 120, 160, 64) -> (N, 120, 160, 3) conv_6 = tf_utils.conv2d( relu_5, output_dim=self.dims[6], k_h=3, k_w=3, d_h=1, d_w=1, name='conv_6', logger=self.logger if is_train is True else None) output = tf_utils.tanh( conv_6, name='output', logger=self.logger if is_train is True else None) # Set reuse=True for next call self.reuse = True self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) return output
def __call__(self, x, is_train=True):
    with tf.variable_scope(self.name, reuse=self.reuse):
        tf_utils.print_activations(x)

        # (N, 100) -> (N, 4, 5, 1024)
        h0_linear = tf_utils.linear(x, 4 * 5 * self.dims[0], name='h0_linear', initializer='He',
                                    logger=self.logger if is_train is True else None)
        h0_reshape = tf.reshape(h0_linear, [tf.shape(h0_linear)[0], 4, 5, self.dims[0]])
        h0_norm = tf_utils.norm(h0_reshape, name='h0_batch', _type='batch', _ops=self._ops, is_train=is_train,
                                logger=self.logger if is_train is True else None)
        h0_relu = tf_utils.relu(h0_norm, name='h0_relu', logger=self.logger if is_train is True else None)

        # (N, 4, 5, 1024) -> (N, 8, 10, 512)
        h1_deconv = tf_utils.deconv2d(h0_relu, output_dim=self.dims[1], name='h1_deconv2d', initializer='He',
                                      logger=self.logger if is_train is True else None)
        h1_norm = tf_utils.norm(h1_deconv, name='h1_batch', _type='batch', _ops=self._ops, is_train=is_train,
                                logger=self.logger if is_train is True else None)
        h1_relu = tf_utils.relu(h1_norm, name='h1_relu', logger=self.logger if is_train is True else None)

        # (N, 8, 10, 512) -> (N, 16, 20, 256)
        h2_deconv = tf_utils.deconv2d(h1_relu, output_dim=self.dims[2], name='h2_deconv2d', initializer='He',
                                      logger=self.logger if is_train is True else None)
        h2_norm = tf_utils.norm(h2_deconv, name='h2_batch', _type='batch', _ops=self._ops, is_train=is_train,
                                logger=self.logger if is_train is True else None)
        h2_relu = tf_utils.relu(h2_norm, name='h2_relu', logger=self.logger if is_train is True else None)

        # (N, 16, 20, 256) -> (N, 15, 20, 256)
        h2_split, _ = tf.split(h2_relu, [15, 1], axis=1, name='h2_split')
        tf_utils.print_activations(h2_split, logger=self.logger if is_train is True else None)

        # (N, 15, 20, 256) -> (N, 30, 40, 128)
        h3_deconv = tf_utils.deconv2d(h2_split, output_dim=self.dims[3], name='h3_deconv2d', initializer='He',
                                      logger=self.logger if is_train is True else None)
        h3_norm = tf_utils.norm(h3_deconv, name='h3_batch', _type='batch', _ops=self._ops, is_train=is_train,
                                logger=self.logger if is_train is True else None)
        h3_relu = tf_utils.relu(h3_norm, name='h3_relu', logger=self.logger if is_train is True else None)

        # (N, 30, 40, 128) -> (N, 60, 80, 64)
        h4_deconv = tf_utils.deconv2d(h3_relu, output_dim=self.dims[4], name='h4_deconv2d', initializer='He',
                                      logger=self.logger if is_train is True else None)
        h4_norm = tf_utils.norm(h4_deconv, name='h4_batch', _type='batch', _ops=self._ops, is_train=is_train,
                                logger=self.logger if is_train is True else None)
        h4_relu = tf_utils.relu(h4_norm, name='h4_relu', logger=self.logger if is_train is True else None)

        # (N, 60, 80, 64) -> (N, 120, 160, 1)
        h5_deconv = tf_utils.deconv2d(h4_relu, output_dim=self.dims[5], name='h5_deconv', initializer='He',
                                      logger=self.logger if is_train is True else None)
        output = tf_utils.tanh(h5_deconv, name='output', logger=self.logger if is_train is True else None)

        # Set reuse=True for next call
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

        return output
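# Illustrative usage of the callable generator classes above; the constructor
# signature, class name, and dims are assumptions for this sketch, not taken
# from the repository.
gen = Generator(name='g', dims=[1024, 512, 256, 128, 64, 1])
z = tf.random_uniform([batch_size, 100], minval=-1., maxval=1.)
fake_train = gen(z, is_train=True)    # first call builds the variables and logs activations
fake_eval = gen(z, is_train=False)    # self.reuse is now True, so the same weights are shared
g_vars = gen.variables                # trainable variables collected under the 'g' scope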
def u_net(self):
    # Stage 1
    tf_utils.print_activations(self.inp_img, logger=self.logger)
    s1_conv1 = tf_utils.conv2d(x=self.inp_img, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s1_conv1', logger=self.logger)
    s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1', logger=self.logger)
    s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s1_conv2', logger=self.logger)
    s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2', logger=self.logger)

    # Stage 2
    s2_maxpool = tf_utils.max_pool(x=s1_conv2, name='s2_maxpool', logger=self.logger)
    s2_conv1 = tf_utils.conv2d(x=s2_maxpool, output_dim=self.conv_dims[2], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s2_conv1', logger=self.logger)
    s2_conv1 = tf_utils.relu(s2_conv1, name='relu_s2_conv1', logger=self.logger)
    s2_conv2 = tf_utils.conv2d(x=s2_conv1, output_dim=self.conv_dims[3], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s2_conv2', logger=self.logger)
    s2_conv2 = tf_utils.relu(s2_conv2, name='relu_s2_conv2', logger=self.logger)

    # Stage 3
    s3_maxpool = tf_utils.max_pool(x=s2_conv2, name='s3_maxpool', logger=self.logger)
    s3_conv1 = tf_utils.conv2d(x=s3_maxpool, output_dim=self.conv_dims[4], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s3_conv1', logger=self.logger)
    s3_conv1 = tf_utils.relu(s3_conv1, name='relu_s3_conv1', logger=self.logger)
    s3_conv2 = tf_utils.conv2d(x=s3_conv1, output_dim=self.conv_dims[5], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s3_conv2', logger=self.logger)
    s3_conv2 = tf_utils.relu(s3_conv2, name='relu_s3_conv2', logger=self.logger)

    # Stage 4
    s4_maxpool = tf_utils.max_pool(x=s3_conv2, name='s4_maxpool', logger=self.logger)
    s4_conv1 = tf_utils.conv2d(x=s4_maxpool, output_dim=self.conv_dims[6], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s4_conv1', logger=self.logger)
    s4_conv1 = tf_utils.relu(s4_conv1, name='relu_s4_conv1', logger=self.logger)
    s4_conv2 = tf_utils.conv2d(x=s4_conv1, output_dim=self.conv_dims[7], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s4_conv2', logger=self.logger)
    s4_conv2 = tf_utils.relu(s4_conv2, name='relu_s4_conv2', logger=self.logger)
    s4_conv2_drop = tf_utils.dropout(x=s4_conv2, keep_prob=self.keep_prob, name='s4_conv2_dropout',
                                     logger=self.logger)

    # Stage 5
    s5_maxpool = tf_utils.max_pool(x=s4_conv2_drop, name='s5_maxpool', logger=self.logger)
    s5_conv1 = tf_utils.conv2d(x=s5_maxpool, output_dim=self.conv_dims[8], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s5_conv1', logger=self.logger)
    s5_conv1 = tf_utils.relu(s5_conv1, name='relu_s5_conv1', logger=self.logger)
    s5_conv2 = tf_utils.conv2d(x=s5_conv1, output_dim=self.conv_dims[9], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s5_conv2', logger=self.logger)
    s5_conv2 = tf_utils.relu(s5_conv2, name='relu_s5_conv2', logger=self.logger)
    s5_conv2_drop = tf_utils.dropout(x=s5_conv2, keep_prob=self.keep_prob, name='s5_conv2_dropout',
                                     logger=self.logger)

    # Stage 6
    s6_deconv1 = tf_utils.deconv2d(x=s5_conv2_drop, output_dim=self.conv_dims[10], k_h=2, k_w=2,
                                   initializer='He', name='s6_deconv1', logger=self.logger)
    s6_deconv1 = tf_utils.relu(s6_deconv1, name='relu_s6_deconv1', logger=self.logger)
    # Cropping
    h1, w1 = s4_conv2_drop.get_shape().as_list()[1:3]
    h2, w2 = s6_deconv1.get_shape().as_list()[1:3]
    s4_conv2_crop = tf.image.crop_to_bounding_box(image=s4_conv2_drop,
                                                  offset_height=int(0.5 * (h1 - h2)),
                                                  offset_width=int(0.5 * (w1 - w2)),
                                                  target_height=h2, target_width=w2)
    tf_utils.print_activations(s4_conv2_crop, logger=self.logger)
    s6_concat = tf_utils.concat(values=[s4_conv2_crop, s6_deconv1], axis=3, name='s6_concat',
                                logger=self.logger)
    s6_conv2 = tf_utils.conv2d(x=s6_concat, output_dim=self.conv_dims[11], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s6_conv2', logger=self.logger)
    s6_conv2 = tf_utils.relu(s6_conv2, name='relu_s6_conv2', logger=self.logger)
    s6_conv3 = tf_utils.conv2d(x=s6_conv2, output_dim=self.conv_dims[12], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s6_conv3', logger=self.logger)
    s6_conv3 = tf_utils.relu(s6_conv3, name='relu_s6_conv3', logger=self.logger)

    # Stage 7
    s7_deconv1 = tf_utils.deconv2d(x=s6_conv3, output_dim=self.conv_dims[13], k_h=2, k_w=2,
                                   initializer='He', name='s7_deconv1', logger=self.logger)
    s7_deconv1 = tf_utils.relu(s7_deconv1, name='relu_s7_deconv1', logger=self.logger)
    # Cropping
    h1, w1 = s3_conv2.get_shape().as_list()[1:3]
    h2, w2 = s7_deconv1.get_shape().as_list()[1:3]
    s3_conv2_crop = tf.image.crop_to_bounding_box(image=s3_conv2,
                                                  offset_height=int(0.5 * (h1 - h2)),
                                                  offset_width=int(0.5 * (w1 - w2)),
                                                  target_height=h2, target_width=w2)
    tf_utils.print_activations(s3_conv2_crop, logger=self.logger)
    s7_concat = tf_utils.concat(values=[s3_conv2_crop, s7_deconv1], axis=3, name='s7_concat',
                                logger=self.logger)
    s7_conv2 = tf_utils.conv2d(x=s7_concat, output_dim=self.conv_dims[14], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s7_conv2', logger=self.logger)
    s7_conv2 = tf_utils.relu(s7_conv2, name='relu_s7_conv2', logger=self.logger)
    s7_conv3 = tf_utils.conv2d(x=s7_conv2, output_dim=self.conv_dims[15], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s7_conv3', logger=self.logger)
    s7_conv3 = tf_utils.relu(s7_conv3, name='relu_s7_conv3', logger=self.logger)

    # Stage 8
    s8_deconv1 = tf_utils.deconv2d(x=s7_conv3, output_dim=self.conv_dims[16], k_h=2, k_w=2,
                                   initializer='He', name='s8_deconv1', logger=self.logger)
    s8_deconv1 = tf_utils.relu(s8_deconv1, name='relu_s8_deconv1', logger=self.logger)
    # Cropping
    h1, w1 = s2_conv2.get_shape().as_list()[1:3]
    h2, w2 = s8_deconv1.get_shape().as_list()[1:3]
    s2_conv2_crop = tf.image.crop_to_bounding_box(image=s2_conv2,
                                                  offset_height=int(0.5 * (h1 - h2)),
                                                  offset_width=int(0.5 * (w1 - w2)),
                                                  target_height=h2, target_width=w2)
    tf_utils.print_activations(s2_conv2_crop, logger=self.logger)
    s8_concat = tf_utils.concat(values=[s2_conv2_crop, s8_deconv1], axis=3, name='s8_concat',
                                logger=self.logger)
    s8_conv2 = tf_utils.conv2d(x=s8_concat, output_dim=self.conv_dims[17], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s8_conv2', logger=self.logger)
    s8_conv2 = tf_utils.relu(s8_conv2, name='relu_s8_conv2', logger=self.logger)
    s8_conv3 = tf_utils.conv2d(x=s8_conv2, output_dim=self.conv_dims[18], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s8_conv3', logger=self.logger)
    s8_conv3 = tf_utils.relu(s8_conv3, name='relu_s8_conv3', logger=self.logger)

    # Stage 9
    s9_deconv1 = tf_utils.deconv2d(x=s8_conv3, output_dim=self.conv_dims[19], k_h=2, k_w=2,
                                   initializer='He', name='s9_deconv1', logger=self.logger)
    s9_deconv1 = tf_utils.relu(s9_deconv1, name='relu_s9_deconv1', logger=self.logger)
    # Cropping
    h1, w1 = s1_conv2.get_shape().as_list()[1:3]
    h2, w2 = s9_deconv1.get_shape().as_list()[1:3]
    s1_conv2_crop = tf.image.crop_to_bounding_box(image=s1_conv2,
                                                  offset_height=int(0.5 * (h1 - h2)),
                                                  offset_width=int(0.5 * (w1 - w2)),
                                                  target_height=h2, target_width=w2)
    tf_utils.print_activations(s1_conv2_crop, logger=self.logger)
    s9_concat = tf_utils.concat(values=[s1_conv2_crop, s9_deconv1], axis=3, name='s9_concat',
                                logger=self.logger)
    s9_conv2 = tf_utils.conv2d(x=s9_concat, output_dim=self.conv_dims[20], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s9_conv2', logger=self.logger)
    s9_conv2 = tf_utils.relu(s9_conv2, name='relu_s9_conv2', logger=self.logger)
    s9_conv3 = tf_utils.conv2d(x=s9_conv2, output_dim=self.conv_dims[21], k_h=3, k_w=3, d_h=1, d_w=1,
                               padding='VALID', initializer='He', name='s9_conv3', logger=self.logger)
    s9_conv3 = tf_utils.relu(s9_conv3, name='relu_s9_conv3', logger=self.logger)

    # 1x1 convolution produces the per-pixel prediction map
    self.pred = tf_utils.conv2d(x=s9_conv3, output_dim=self.conv_dims[22], k_h=1, k_w=1, d_h=1, d_w=1,
                                padding='SAME', initializer='He', name='output', logger=self.logger)
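# The crop-and-concatenate step above is repeated at every decoder stage; a small
# helper like the hypothetical center_crop_to() below (not part of the original
# code) captures the same center-crop arithmetic in one place.
def center_crop_to(src, ref):
    # Center-crop `src` (NHWC) so its spatial size matches `ref`.
    h1, w1 = src.get_shape().as_list()[1:3]
    h2, w2 = ref.get_shape().as_list()[1:3]
    return tf.image.crop_to_bounding_box(image=src,
                                         offset_height=int(0.5 * (h1 - h2)),
                                         offset_width=int(0.5 * (w1 - w2)),
                                         target_height=h2, target_width=w2)

# e.g. Stage 6 would then read: s4_conv2_crop = center_crop_to(s4_conv2_drop, s6_deconv1)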
def forward_network(self, img, padding='SAME', reuse=False):
    with tf.compat.v1.variable_scope(self.name, reuse=reuse):
        # Stage 0
        s0_conv1 = tf_utils.conv2d(x=img, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s0_conv1')
        s0_conv1 = tf_utils.relu(s0_conv1, name='relu_s0_conv1')
        s0_conv2 = tf_utils.conv2d(x=s0_conv1, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s0_conv2')
        s0_conv2 = tf_utils.norm(s0_conv2, name='s0_norm1', _type='batch', _ops=self._ops, is_train=False)
        s0_conv2 = tf_utils.relu(s0_conv2, name='relu_s0_conv2')

        # Stage 1
        s1_maxpool = tf_utils.max_pool(x=s0_conv2, name='s1_maxpool2d')
        s1_conv1 = tf_utils.conv2d(x=s1_maxpool, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s1_conv1')
        s1_conv1 = tf_utils.norm(s1_conv1, name='s1_norm0', _type='batch', _ops=self._ops, is_train=False)
        s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1')
        s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s1_conv2')
        s1_conv2 = tf_utils.norm(s1_conv2, name='s1_norm1', _type='batch', _ops=self._ops, is_train=False)
        s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2')

        # Stage 2
        s2_maxpool = tf_utils.max_pool(x=s1_conv2, name='s2_maxpool2d')
        s2_conv1 = tf_utils.conv2d(x=s2_maxpool, output_dim=self.conv_dims[2], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s2_conv1')
        s2_conv1 = tf_utils.norm(s2_conv1, name='s2_norm0', _type='batch', _ops=self._ops, is_train=False)
        s2_conv1 = tf_utils.relu(s2_conv1, name='relu_s2_conv1')
        s2_conv2 = tf_utils.conv2d(x=s2_conv1, output_dim=self.conv_dims[3], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s2_conv2')
        s2_conv2 = tf_utils.norm(s2_conv2, name='s2_norm1', _type='batch', _ops=self._ops, is_train=False)
        s2_conv2 = tf_utils.relu(s2_conv2, name='relu_s2_conv2')

        # Stage 3
        s3_maxpool = tf_utils.max_pool(x=s2_conv2, name='s3_maxpool2d')
        s3_conv1 = tf_utils.conv2d(x=s3_maxpool, output_dim=self.conv_dims[4], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s3_conv1')
        s3_conv1 = tf_utils.norm(s3_conv1, name='s3_norm0', _type='batch', _ops=self._ops, is_train=False)
        s3_conv1 = tf_utils.relu(s3_conv1, name='relu_s3_conv1')
        s3_conv2 = tf_utils.conv2d(x=s3_conv1, output_dim=self.conv_dims[5], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s3_conv2')
        s3_conv2 = tf_utils.norm(s3_conv2, name='s3_norm1', _type='batch', _ops=self._ops, is_train=False)
        s3_conv2 = tf_utils.relu(s3_conv2, name='relu_s3_conv2')

        # Stage 4
        s4_maxpool = tf_utils.max_pool(x=s3_conv2, name='s4_maxpool2d')
        s4_conv1 = tf_utils.conv2d(x=s4_maxpool, output_dim=self.conv_dims[6], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s4_conv1')
        s4_conv1 = tf_utils.norm(s4_conv1, name='s4_norm0', _type='batch', _ops=self._ops, is_train=False)
        s4_conv1 = tf_utils.relu(s4_conv1, name='relu_s4_conv1')
        s4_conv2 = tf_utils.conv2d(x=s4_conv1, output_dim=self.conv_dims[7], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s4_conv2')
        s4_conv2 = tf_utils.norm(s4_conv2, name='s4_norm1', _type='batch', _ops=self._ops, is_train=False)
        s4_conv2 = tf_utils.relu(s4_conv2, name='relu_s4_conv2')
        # keep_prob must lie in (0, 1]; 1. keeps every unit, i.e. dropout is effectively
        # disabled in this inference graph (assumed intent; a keep probability of 0.
        # would be invalid for tf.nn.dropout-style utilities).
        s4_conv2_drop = tf_utils.dropout(x=s4_conv2, keep_prob=1., name='s4_dropout')

        # Stage 5
        s5_maxpool = tf_utils.max_pool(x=s4_conv2_drop, name='s5_maxpool2d')
        s5_conv1 = tf_utils.conv2d(x=s5_maxpool, output_dim=self.conv_dims[8], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s5_conv1')
        s5_conv1 = tf_utils.norm(s5_conv1, name='s5_norm0', _type='batch', _ops=self._ops, is_train=False)
        s5_conv1 = tf_utils.relu(s5_conv1, name='relu_s5_conv1')
        s5_conv2 = tf_utils.conv2d(x=s5_conv1, output_dim=self.conv_dims[9], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s5_conv2')
        s5_conv2 = tf_utils.norm(s5_conv2, name='s5_norm1', _type='batch', _ops=self._ops, is_train=False)
        s5_conv2 = tf_utils.relu(s5_conv2, name='relu_s5_conv2')
        s5_conv2_drop = tf_utils.dropout(x=s5_conv2, keep_prob=1., name='s5_dropout')

        # Stage 6
        s6_deconv1 = tf_utils.deconv2d(x=s5_conv2_drop, output_dim=self.conv_dims[10], k_h=2, k_w=2,
                                       initializer='He', name='s6_deconv1')
        s6_deconv1 = tf_utils.norm(s6_deconv1, name='s6_norm0', _type='batch', _ops=self._ops, is_train=False)
        s6_deconv1 = tf_utils.relu(s6_deconv1, name='relu_s6_deconv1')
        # Cropping: trim the upsampled width so it matches the skip connection
        w1 = s4_conv2_drop.get_shape().as_list()[2]
        w2 = s6_deconv1.get_shape().as_list()[2] - s4_conv2_drop.get_shape().as_list()[2]
        s6_deconv1_split, _ = tf.split(s6_deconv1, num_or_size_splits=[w1, w2], axis=2, name='axis2_split')
        tf_utils.print_activations(s6_deconv1_split)
        # Concat
        s6_concat = tf_utils.concat(values=[s6_deconv1_split, s4_conv2_drop], axis=3, name='s6_axis3_concat')
        s6_conv2 = tf_utils.conv2d(x=s6_concat, output_dim=self.conv_dims[11], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s6_conv2')
        s6_conv2 = tf_utils.norm(s6_conv2, name='s6_norm1', _type='batch', _ops=self._ops, is_train=False)
        s6_conv2 = tf_utils.relu(s6_conv2, name='relu_s6_conv2')
        s6_conv3 = tf_utils.conv2d(x=s6_conv2, output_dim=self.conv_dims[12], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s6_conv3')
        s6_conv3 = tf_utils.norm(s6_conv3, name='s6_norm2', _type='batch', _ops=self._ops, is_train=False)
        s6_conv3 = tf_utils.relu(s6_conv3, name='relu_s6_conv3')

        # Stage 7
        s7_deconv1 = tf_utils.deconv2d(x=s6_conv3, output_dim=self.conv_dims[13], k_h=2, k_w=2,
                                       initializer='He', name='s7_deconv1')
        s7_deconv1 = tf_utils.norm(s7_deconv1, name='s7_norm0', _type='batch', _ops=self._ops, is_train=False)
        s7_deconv1 = tf_utils.relu(s7_deconv1, name='relu_s7_deconv1')
        # Concat
        s7_concat = tf_utils.concat(values=[s7_deconv1, s3_conv2], axis=3, name='s7_axis3_concat')
        s7_conv2 = tf_utils.conv2d(x=s7_concat, output_dim=self.conv_dims[14], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s7_conv2')
        s7_conv2 = tf_utils.norm(s7_conv2, name='s7_norm1', _type='batch', _ops=self._ops, is_train=False)
        s7_conv2 = tf_utils.relu(s7_conv2, name='relu_s7_conv2')
        s7_conv3 = tf_utils.conv2d(x=s7_conv2, output_dim=self.conv_dims[15], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s7_conv3')
        s7_conv3 = tf_utils.norm(s7_conv3, name='s7_norm2', _type='batch', _ops=self._ops, is_train=False)
        s7_conv3 = tf_utils.relu(s7_conv3, name='relu_s7_conv3')

        # Stage 8
        s8_deconv1 = tf_utils.deconv2d(x=s7_conv3, output_dim=self.conv_dims[16], k_h=2, k_w=2,
                                       initializer='He', name='s8_deconv1')
        s8_deconv1 = tf_utils.norm(s8_deconv1, name='s8_norm0', _type='batch', _ops=self._ops, is_train=False)
        s8_deconv1 = tf_utils.relu(s8_deconv1, name='relu_s8_deconv1')
        # Concat
        s8_concat = tf_utils.concat(values=[s8_deconv1, s2_conv2], axis=3, name='s8_axis3_concat')
        s8_conv2 = tf_utils.conv2d(x=s8_concat, output_dim=self.conv_dims[17], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s8_conv2')
        s8_conv2 = tf_utils.norm(s8_conv2, name='s8_norm1', _type='batch', _ops=self._ops, is_train=False)
        s8_conv2 = tf_utils.relu(s8_conv2, name='relu_s8_conv2')
        s8_conv3 = tf_utils.conv2d(x=s8_conv2, output_dim=self.conv_dims[18], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s8_conv3')
        s8_conv3 = tf_utils.norm(s8_conv3, name='s8_norm2', _type='batch', _ops=self._ops, is_train=False)
        s8_conv3 = tf_utils.relu(s8_conv3, name='relu_s8_conv3')

        # Stage 9
        s9_deconv1 = tf_utils.deconv2d(x=s8_conv3, output_dim=self.conv_dims[19], k_h=2, k_w=2,
                                       initializer='He', name='s9_deconv1')
        s9_deconv1 = tf_utils.norm(s9_deconv1, name='s9_norm0', _type='batch', _ops=self._ops, is_train=False)
        s9_deconv1 = tf_utils.relu(s9_deconv1, name='relu_s9_deconv1')
        # Concat
        s9_concat = tf_utils.concat(values=[s9_deconv1, s1_conv2], axis=3, name='s9_axis3_concat')
        s9_conv2 = tf_utils.conv2d(x=s9_concat, output_dim=self.conv_dims[20], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s9_conv2')
        s9_conv2 = tf_utils.norm(s9_conv2, name='s9_norm1', _type='batch', _ops=self._ops, is_train=False)
        s9_conv2 = tf_utils.relu(s9_conv2, name='relu_s9_conv2')
        s9_conv3 = tf_utils.conv2d(x=s9_conv2, output_dim=self.conv_dims[21], k_h=3, k_w=3, d_h=1, d_w=1,
                                   padding=padding, initializer='He', name='s9_conv3')
        s9_conv3 = tf_utils.norm(s9_conv3, name='s9_norm2', _type='batch', _ops=self._ops, is_train=False)
        s9_conv3 = tf_utils.relu(s9_conv3, name='relu_s9_conv3')

        # Stage 10
        s10_deconv1 = tf_utils.deconv2d(x=s9_conv3, output_dim=self.conv_dims[-1], k_h=2, k_w=2,
                                        initializer='He', name='s10_deconv1')
        s10_deconv1 = tf_utils.norm(s10_deconv1, name='s10_norm0', _type='batch', _ops=self._ops,
                                    is_train=False)
        s10_deconv1 = tf_utils.relu(s10_deconv1, name='relu_s10_deconv1')
        # Concat
        s10_concat = tf_utils.concat(values=[s10_deconv1, s0_conv2], axis=3, name='s10_axis3_concat')
        s10_conv2 = tf_utils.conv2d(s10_concat, output_dim=self.conv_dims[-1], k_h=3, k_w=3, d_h=1, d_w=1,
                                    padding=padding, initializer='He', name='s10_conv2')
        s10_conv2 = tf_utils.norm(s10_conv2, name='s10_norm1', _type='batch', _ops=self._ops, is_train=False)
        s10_conv2 = tf_utils.relu(s10_conv2, name='relu_s10_conv2')
        s10_conv3 = tf_utils.conv2d(x=s10_conv2, output_dim=self.conv_dims[-1], k_h=3, k_w=3, d_h=1, d_w=1,
                                    padding=padding, initializer='He', name='s10_conv3')
        s10_conv3 = tf_utils.norm(s10_conv3, name='s10_norm2', _type='batch', _ops=self._ops, is_train=False)
        s10_conv3 = tf_utils.relu(s10_conv3, name='relu_s10_conv3')

        output = tf_utils.conv2d(s10_conv3, output_dim=self.num_classes, k_h=1, k_w=1, d_h=1, d_w=1,
                                 padding=padding, initializer='He', name='output')

        return output
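# Hedged sketch of how the per-pixel class logits returned by forward_network()
# are commonly consumed; `model`, `labels` (int32 class ids per pixel), and the
# reduction are illustrative assumptions rather than the repository's code.
logits = model.forward_network(img)            # (N, H, W, num_classes)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
pred_mask = tf.argmax(logits, axis=-1)         # per-pixel class prediction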
def __call__(self, x):
    with tf.compat.v1.variable_scope(self.name, reuse=self.reuse):
        tf_utils.print_activations(x, logger=self.logger)

        # h0: (320, 200) -> (160, 100)
        h0_conv2d = tf_utils.conv2d(x, output_dim=self.dis_c[0], initializer='He', logger=self.logger,
                                    name='h0_conv2d')
        h0_lrelu = tf_utils.lrelu(h0_conv2d, logger=self.logger, name='h0_lrelu')

        # h1: (160, 100) -> (80, 50)
        h1_conv2d = tf_utils.conv2d(h0_lrelu, output_dim=self.dis_c[1], initializer='He', logger=self.logger,
                                    name='h1_conv2d')
        h1_norm = tf_utils.norm(h1_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='h1_norm')
        h1_lrelu = tf_utils.lrelu(h1_norm, logger=self.logger, name='h1_lrelu')

        # h2: (80, 50) -> (40, 25)
        h2_conv2d = tf_utils.conv2d(h1_lrelu, output_dim=self.dis_c[2], initializer='He', logger=self.logger,
                                    name='h2_conv2d')
        h2_norm = tf_utils.norm(h2_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='h2_norm')
        h2_lrelu = tf_utils.lrelu(h2_norm, logger=self.logger, name='h2_lrelu')

        # h3: (40, 25) -> (20, 13)
        h3_conv2d = tf_utils.conv2d(h2_lrelu, output_dim=self.dis_c[3], initializer='He', logger=self.logger,
                                    name='h3_conv2d')
        h3_norm = tf_utils.norm(h3_conv2d, _type=self.norm, _ops=self._ops, logger=self.logger, name='h3_norm')
        h3_lrelu = tf_utils.lrelu(h3_norm, logger=self.logger, name='h3_lrelu')

        # output: (20, 13) -> (20, 13), un-activated score map
        output = tf_utils.conv2d(h3_lrelu, output_dim=self.dis_c[4], k_h=3, k_w=3, d_h=1, d_w=1,
                                 initializer='He', logger=self.logger, name='output_conv2d')

        # set reuse=True for next call
        self.reuse = True
        self.variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
                                                     scope=self.name)

        return output
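# Because the discriminator above returns an un-activated score map rather than a
# single scalar, a typical GAN loss averages a per-patch objective over that map.
# This is only a hedged sketch; `dis`, `real_img`, and `fake_img` are assumed names.
d_real = dis(real_img)
d_fake = dis(fake_img)
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(d_real), logits=d_real))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.zeros_like(d_fake), logits=d_fake))
d_loss = d_loss_real + d_loss_fake
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(d_fake), logits=d_fake))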
def forward_network(self, input_img, padding='SAME', reuse=False):
    with tf.compat.v1.variable_scope(self.name, reuse=reuse):
        tf_utils.print_activations(input_img, logger=self.logger)
        inputs = self.conv2d_fixed_padding(inputs=input_img, filters=64, kernel_size=7, strides=2,
                                           name='conv1')
        inputs = tf_utils.max_pool(inputs, name='3x3_maxpool', ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                                   logger=self.logger)

        inputs = self.block_layer(inputs=inputs, filters=64, block_fn=self.bottleneck_block,
                                  blocks=self.layers[0], strides=1, train_mode=self.train_mode,
                                  name='block_layer1')
        inputs = self.block_layer(inputs=inputs, filters=128, block_fn=self.bottleneck_block,
                                  blocks=self.layers[1], strides=2, train_mode=self.train_mode,
                                  name='block_layer2')
        inputs = self.block_layer(inputs=inputs, filters=256, block_fn=self.bottleneck_block,
                                  blocks=self.layers[2], strides=2, train_mode=self.train_mode,
                                  name='block_layer3')
        inputs = self.block_layer(inputs=inputs, filters=512, block_fn=self.bottleneck_block,
                                  blocks=self.layers[3], strides=2, train_mode=self.train_mode,
                                  name='block_layer4')

        inputs = tf_utils.norm(inputs, name='before_gap_batch_norm', _type='batch', _ops=self._ops,
                               is_train=self.train_mode, logger=self.logger)
        inputs = tf_utils.relu(inputs, name='before_gap_relu', logger=self.logger)

        _, h, w, _ = inputs.get_shape().as_list()
        inputs = tf_utils.avg_pool(inputs, name='gap', ksize=[1, h, w, 1], strides=[1, 1, 1, 1],
                                   logger=self.logger)
        inputs = tf_utils.flatten(inputs, name='flatten', logger=self.logger)
        logits = tf_utils.linear(inputs, self.num_classes, name='logits')

        return logits
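# Hedged sketch of consuming the classifier logits above. The `_ops` list threaded
# through tf_utils.norm typically holds the batch-norm moving-average update ops,
# so a training step usually groups them with the optimizer; the optimizer choice
# and `labels` are illustrative assumptions.
logits = model.forward_network(input_img, reuse=False)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
optim = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)
train_op = tf.group(optim, *model._ops)   # run BN updates together with the gradient step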
def __call__(self, x):
    with tf.variable_scope(self.name, reuse=self.reuse):
        tf_utils.print_activations(x)

        # conv: (N, H, W, C) -> (N, H/2, W/2, 64)
        output = tf_utils.conv2d(x, self.conv_dims[0], k_h=4, k_w=4, d_h=2, d_w=2, padding='SAME',
                                 name='conv0_conv2d')
        output = tf_utils.lrelu(output, name='conv0_lrelu', is_print=True)

        for idx, conv_dim in enumerate(self.conv_dims[1:]):
            # conv: (N, H/2, W/2, C) -> (N, H/4, W/4, 2C)
            output = tf_utils.conv2d(output, conv_dim, k_h=4, k_w=4, d_h=2, d_w=2, padding='SAME',
                                     name='conv{}_conv2d'.format(idx + 1))
            output = tf_utils.norm(output, _type=self.norm, _ops=self._ops, name='conv{}_norm'.format(idx + 1))
            output = tf_utils.lrelu(output, name='conv{}_lrelu'.format(idx + 1), is_print=True)

        for idx, deconv_dim in enumerate(self.deconv_dims):
            # deconv: (N, H/16, W/16, C) -> (N, H/8, W/8, C/2)
            output = tf_utils.deconv2d(output, deconv_dim, k_h=4, k_w=4, name='deconv{}_conv2d'.format(idx))
            output = tf_utils.norm(output, _type=self.norm, _ops=self._ops, name='deconv{}_norm'.format(idx))
            output = tf_utils.relu(output, name='deconv{}_relu'.format(idx), is_print=True)

        # split (N, 152, 104, 64) to (N, 150, 104, 64)
        shapeA = int(self.img_size[0] / 2)
        shapeB = output.get_shape().as_list()[1] - shapeA
        output, _ = tf.split(output, [shapeA, shapeB], axis=1, name='split_0')
        tf_utils.print_activations(output)

        # split (N, 150, 104, 64) to (N, 150, 100, 64)
        shapeA = int(self.img_size[1] / 2)
        shapeB = output.get_shape().as_list()[2] - shapeA
        output, _ = tf.split(output, [shapeA, shapeB], axis=2, name='split_1')
        tf_utils.print_activations(output)

        # deconv: (N, H/2, W/2, 64) -> (N, H, W, 3)
        output = tf_utils.deconv2d(output, self.img_size[2], k_h=4, k_w=4, name='conv3_deconv2d')
        output = tf_utils.tanh(output, name='conv4_tanh', is_print=True)

        # set reuse=True for next call
        self.reuse = True
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

        return output