def task(x, activation='relu', output_dim=256, scope='task_network', norm='layer', b_train=False):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if activation == 'swish':
            act_func = util.swish
        elif activation == 'relu':
            act_func = tf.nn.relu
        elif activation == 'lrelu':
            act_func = tf.nn.leaky_relu
        else:
            act_func = tf.nn.sigmoid

        print('Task Layer1: ' + str(x.get_shape().as_list()))

        block_depth = dense_block_depth

        l = x
        l = layers.conv(l, scope='conv1', filter_dims=[3, 3, block_depth], stride_dims=[1, 1],
                        non_linear_fn=None, bias=False, dilation=[1, 1, 1, 1])

        if norm == 'layer':
            l = layers.layer_norm(l, scope='ln1')
        elif norm == 'batch':
            l = layers.batch_norm_conv(l, b_train=b_train, scope='bn1')

        l = act_func(l)

        for i in range(15):
            l = layers.add_residual_block(l, filter_dims=[3, 3, block_depth], num_layers=2,
                                          act_func=act_func, norm=norm, b_train=b_train,
                                          scope='block1_' + str(i))

        latent = layers.global_avg_pool(l, output_length=output_dim)

        return latent
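
# Hypothetical usage sketch (assumption, not part of the original code): how task() might be
# called on a feature map. The placeholder name, spatial size, and channel count are made up;
# the context map produced by encoder() below appears to be the intended input in this file.
def _task_usage_sketch():
    feature_map = tf.placeholder(tf.float32, [None, 7, 7, dense_block_depth], name='feature_map')
    latent = task(feature_map, activation='relu', output_dim=256, norm='layer', b_train=True)
    return latent
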
def decoder_network(latent, anchor_layer=None, activation='swish', scope='g_decoder_network', bn_phaze=False):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if activation == 'swish':
            act_func = util.swish
        elif activation == 'relu':
            act_func = tf.nn.relu
        elif activation == 'lrelu':
            act_func = tf.nn.leaky_relu
        else:
            act_func = tf.nn.sigmoid

        #l = tf.cond(bn_phaze, lambda: latent, lambda: make_multi_modal_noise(8))
        l = tf.cond(bn_phaze, lambda: latent, lambda: latent)

        l = layers.fc(l, 6 * 6 * 32, non_linear_fn=act_func)

        print('decoder input:', str(latent.get_shape().as_list()))

        l = tf.reshape(l, shape=[-1, 6, 6, 32])

        l = add_residual_block(l, filter_dims=[3, 3, g_dense_block_depth * 4], num_layers=4,
                               act_func=act_func, bn_phaze=bn_phaze, use_residual=False, scope='block_0')

        print('block 0:', str(l.get_shape().as_list()))

        l = layers.batch_norm_conv(l, b_train=bn_phaze, scope='bn1')
        l = act_func(l)

        # 12 x 12
        l = layers.deconv(l, b_size=batch_size, scope='g_dec_deconv1',
                          filter_dims=[3, 3, g_dense_block_depth * 3],
                          stride_dims=[2, 2], padding='SAME', non_linear_fn=None)

        print('deconv1:', str(l.get_shape().as_list()))

        l = add_residual_block(l, filter_dims=[3, 3, g_dense_block_depth * 3], num_layers=4,
                               act_func=act_func, bn_phaze=bn_phaze, use_residual=False, scope='block_1',
                               use_dilation=True)

        l = layers.batch_norm_conv(l, b_train=bn_phaze, scope='bn2')
        l = act_func(l)

        # 24 x 24
        l = layers.deconv(l, b_size=batch_size, scope='g_dec_deconv2',
                          filter_dims=[3, 3, g_dense_block_depth * 2],
                          stride_dims=[2, 2], padding='SAME', non_linear_fn=None)

        print('deconv2:', str(l.get_shape().as_list()))

        l = add_residual_block(l, filter_dims=[3, 3, g_dense_block_depth * 2], num_layers=4,
                               act_func=act_func, bn_phaze=bn_phaze, use_residual=False, scope='block_2',
                               use_dilation=True)

        l = layers.batch_norm_conv(l, b_train=bn_phaze, scope='bn3')
        l = act_func(l)

        # 48 x 48
        l = layers.deconv(l, b_size=batch_size, scope='g_dec_deconv3',
                          filter_dims=[3, 3, g_dense_block_depth],
                          stride_dims=[2, 2], padding='SAME', non_linear_fn=None)

        print('deconv3:', str(l.get_shape().as_list()))

        l = add_residual_block(l, filter_dims=[3, 3, g_dense_block_depth], num_layers=4,
                               act_func=act_func, bn_phaze=bn_phaze, use_residual=False, scope='block_3',
                               use_dilation=True)

        l = layers.batch_norm_conv(l, b_train=bn_phaze, scope='bn4')
        l = act_func(l)

        l = layers.self_attention(l, g_dense_block_depth, act_func=act_func)

        if anchor_layer is not None:
            l = tf.concat([l, anchor_layer], axis=3)

        # 96 x 96
        l = layers.deconv(l, b_size=batch_size, scope='g_dec_deconv4',
                          filter_dims=[3, 3, g_dense_block_depth],
                          stride_dims=[2, 2], padding='SAME', non_linear_fn=None)

        l = add_residual_block(l, filter_dims=[3, 3, g_dense_block_depth], num_layers=2,
                               act_func=act_func, bn_phaze=bn_phaze, use_residual=False, scope='block_4',
                               use_dilation=True)

        l = layers.add_dense_transition_layer(l, filter_dims=[1, 1, 3], act_func=act_func,
                                              scope='dense_transition_1', bn_phaze=bn_phaze,
                                              use_pool=False)

        l = add_residual_block(l, filter_dims=[3, 3, 3], num_layers=2,
                               act_func=act_func, bn_phaze=bn_phaze, use_residual=False, scope='block_5',
                               use_dilation=True)

        l = tf.nn.tanh(l)

        print('final:', str(l.get_shape().as_list()))

        return l
def encoder_network(x, activation='relu', scope='encoder_network', reuse=False, bn_phaze=False, keep_prob=0.5):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # if reuse:
        #     tf.get_variable_scope().reuse_variables()

        if activation == 'swish':
            act_func = util.swish
        elif activation == 'relu':
            act_func = tf.nn.relu
        elif activation == 'lrelu':
            act_func = tf.nn.leaky_relu
        else:
            act_func = tf.nn.sigmoid

        # [96 x 96]
        l = layers.conv(x, scope='conv1', filter_dims=[3, 3, g_dense_block_depth], stride_dims=[1, 1],
                        non_linear_fn=None, bias=False, dilation=[1, 1, 1, 1])

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth], num_layers=2,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_0')

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth], num_layers=2,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_1')

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth], num_layers=2,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_1_1')

        l = layers.batch_norm_conv(l, b_train=bn_phaze, scope='bn1')
        l = act_func(l)

        # [48 x 48]
        #l = tf.nn.avg_pool(l, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        l = layers.conv(l, scope='conv2', filter_dims=[3, 3, g_dense_block_depth], stride_dims=[2, 2],
                        non_linear_fn=act_func, bias=False, dilation=[1, 1, 1, 1])

        l = layers.self_attention(l, g_dense_block_depth)

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth], num_layers=2,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_2')

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth], num_layers=2,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_3')

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth], num_layers=2,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_3_1')

        l = layers.batch_norm_conv(l, b_train=bn_phaze, scope='bn2')
        l = act_func(l)

        l_share = l

        # [24 x 24]
        #l = tf.nn.avg_pool(l, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        l = layers.conv(l, scope='conv3', filter_dims=[3, 3, g_dense_block_depth * 2], stride_dims=[2, 2],
                        non_linear_fn=None, bias=False, dilation=[1, 1, 1, 1])

        l = layers.add_dense_transition_layer(l, filter_dims=[1, 1, g_dense_block_depth * 2],
                                              act_func=act_func, scope='dense_transition_24',
                                              bn_phaze=bn_phaze, use_pool=False)

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth * 2], num_layers=3,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_4')

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth * 2], num_layers=3,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_5')

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth * 2], num_layers=3,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_5_1')

        l = layers.batch_norm_conv(l, b_train=bn_phaze, scope='bn3')
        l = act_func(l)

        # [12 x 12]
        #l = tf.nn.avg_pool(l, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        l = layers.conv(l, scope='conv4', filter_dims=[3, 3, g_dense_block_depth * 3], stride_dims=[2, 2],
                        non_linear_fn=None, bias=False, dilation=[1, 1, 1, 1])

        l = layers.add_dense_transition_layer(l, filter_dims=[1, 1, g_dense_block_depth * 3],
                                              act_func=act_func, scope='dense_transition_12',
                                              bn_phaze=bn_phaze, use_pool=False)

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth * 3], num_layers=3,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_6')

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth * 3], num_layers=3,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_7')

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth * 3], num_layers=3,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_7_1')

        l = layers.batch_norm_conv(l, b_train=bn_phaze, scope='bn4')
        l = act_func(l)

        # [6 x 6]
        #l = tf.nn.avg_pool(l, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        l = layers.conv(l, scope='conv5', filter_dims=[3, 3, g_dense_block_depth * 4], stride_dims=[2, 2],
                        non_linear_fn=None, bias=False, dilation=[1, 1, 1, 1])

        l = layers.add_dense_transition_layer(l, filter_dims=[1, 1, g_dense_block_depth * 4],
                                              act_func=act_func, scope='dense_transition_6',
                                              bn_phaze=bn_phaze, use_pool=False)

        l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth * 4], num_layers=3,
                                     act_func=act_func, bn_phaze=bn_phaze, scope='block_8')

        #l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth * 4], num_layers=3,
        #                             act_func=act_func, bn_phaze=bn_phaze, scope='block_9')

        #l = add_residual_dense_block(l, filter_dims=[3, 3, g_dense_block_depth * 4], num_layers=3,
        #                             act_func=act_func, bn_phaze=bn_phaze, scope='block_10')

        with tf.variable_scope('dense_block_last'):
            scale_layer = layers.add_dense_transition_layer(l, filter_dims=[1, 1, representation_dim],
                                                            act_func=act_func, scope='dense_transition_1',
                                                            bn_phaze=bn_phaze, use_pool=False)

            last_dense_layer = layers.add_dense_transition_layer(l, filter_dims=[1, 1, representation_dim],
                                                                 act_func=act_func, scope='dense_transition_2',
                                                                 bn_phaze=bn_phaze, use_pool=False)

            scale_layer = act_func(scale_layer)
            last_dense_layer = act_func(last_dense_layer)

        return last_dense_layer, scale_layer, l_share
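
# Hypothetical wiring sketch (assumption, not in the original code): one plausible way to pair
# encoder_network() with decoder_network() above. The 48x48 shared feature map (l_share) is
# reused as the decoder's anchor_layer skip connection, and the dense head is pooled to a
# vector with layers.global_avg_pool, mirroring the call used in task(). bn_phaze is assumed
# to be a tf.bool tensor because decoder_network() feeds it to tf.cond.
def _autoencoder_wiring_sketch(images, bn_phaze):
    last_dense_layer, scale_layer, l_share = encoder_network(images, activation='relu', bn_phaze=bn_phaze)
    latent = layers.global_avg_pool(last_dense_layer, output_length=representation_dim)
    recon = decoder_network(latent, anchor_layer=l_share, activation='swish', bn_phaze=bn_phaze)
    return recon
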
def encoder(x, activation='relu', scope='encoder_network', norm='layer', b_train=False):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if activation == 'swish':
            act_func = util.swish
        elif activation == 'relu':
            act_func = tf.nn.relu
        elif activation == 'lrelu':
            act_func = tf.nn.leaky_relu
        else:
            act_func = tf.nn.sigmoid

        # [192 x 192]
        block_depth = dense_block_depth // 4

        l = layers.conv(x, scope='conv1', filter_dims=[5, 5, block_depth], stride_dims=[1, 1],
                        non_linear_fn=None, bias=False, dilation=[1, 1, 1, 1])

        if norm == 'layer':
            l = layers.layer_norm(l, scope='ln0')
        elif norm == 'batch':
            l = layers.batch_norm_conv(l, b_train=b_train, scope='bn0')

        l = act_func(l)

        for i in range(4):
            l = layers.add_residual_dense_block(l, filter_dims=[3, 3, block_depth], num_layers=2,
                                                act_func=act_func, norm=norm, b_train=b_train,
                                                scope='dense_block_1_' + str(i))

        # [64 x 64]
        block_depth = block_depth * 2

        l = layers.conv(l, scope='tr1', filter_dims=[3, 3, block_depth], stride_dims=[2, 2], non_linear_fn=None)

        if norm == 'layer':
            l = layers.layer_norm(l, scope='ln1')
        elif norm == 'batch':
            l = layers.batch_norm_conv(l, b_train=b_train, scope='bn1')

        l = act_func(l)

        print('Encoder Block 1: ' + str(l.get_shape().as_list()))

        for i in range(2):
            l = layers.add_residual_block(l, filter_dims=[3, 3, block_depth], num_layers=2,
                                          act_func=act_func, norm=norm, b_train=b_train,
                                          scope='res_block_1_' + str(i))

        # [32 x 32]
        block_depth = block_depth * 2

        l = layers.conv(l, scope='tr2', filter_dims=[3, 3, block_depth], stride_dims=[2, 2], non_linear_fn=None)

        if norm == 'layer':
            l = layers.layer_norm(l, scope='ln2')
        elif norm == 'batch':
            l = layers.batch_norm_conv(l, b_train=b_train, scope='bn2')

        l = act_func(l)

        print('Encoder Block 2: ' + str(l.get_shape().as_list()))

        for i in range(2):
            l = layers.add_residual_block(l, filter_dims=[3, 3, block_depth], num_layers=2,
                                          act_func=act_func, norm=norm, b_train=b_train,
                                          scope='res_block_2_' + str(i))

        # [16 x 16]
        block_depth = block_depth * 2

        l = layers.conv(l, scope='tr3', filter_dims=[3, 3, block_depth], stride_dims=[2, 2], non_linear_fn=None)

        if norm == 'layer':
            l = layers.layer_norm(l, scope='ln3')
        elif norm == 'batch':
            l = layers.batch_norm_conv(l, b_train=b_train, scope='bn3')

        l = act_func(l)

        print('Encoder Block 3: ' + str(l.get_shape().as_list()))

        for i in range(2):
            l = layers.add_residual_block(l, filter_dims=[3, 3, block_depth], num_layers=2,
                                          act_func=act_func, norm=norm, b_train=b_train,
                                          scope='res_block_3' + str(i))

        # [8 x 8]
        block_depth = block_depth * 2

        l = layers.conv(l, scope='tr4', filter_dims=[3, 3, block_depth], stride_dims=[2, 2], non_linear_fn=None)

        if norm == 'layer':
            l = layers.layer_norm(l, scope='ln4')
        elif norm == 'batch':
            l = layers.batch_norm_conv(l, b_train=b_train, scope='bn4')

        l = act_func(l)

        print('Encoder Block 4: ' + str(l.get_shape().as_list()))

        for i in range(2):
            l = layers.add_residual_block(l, filter_dims=[3, 3, block_depth], num_layers=2,
                                          act_func=act_func, norm=norm, b_train=b_train, use_dilation=True,
                                          scope='res_block_4_' + str(i))

        # [4 x 4]
        block_depth = block_depth * 2

        l = layers.conv(l, scope='tr5', filter_dims=[3, 3, block_depth], stride_dims=[2, 2], non_linear_fn=None)

        print('Encoder Block 5: ' + str(l.get_shape().as_list()))

        if norm == 'layer':
            l = layers.layer_norm(l, scope='ln5')
        elif norm == 'batch':
            l = layers.batch_norm_conv(l, b_train=b_train, scope='bn5')

        l = act_func(l)

        for i in range(2):
            l = layers.add_residual_block(l, filter_dims=[3, 3, block_depth], num_layers=2,
                                          act_func=act_func, norm=norm, b_train=b_train, use_dilation=True,
                                          scope='res_block_5_' + str(i))

        last_layer = l

        context = layers.global_avg_pool(last_layer, output_length=representation_dim, use_bias=True, scope='gp')

        print('Encoder GP Dims: ' + str(context.get_shape().as_list()))

        context = tf.reshape(context, [batch_size, num_context_patches, num_context_patches, -1])

        print('Context Dims: ' + str(context.get_shape().as_list()))

        return context
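
# Hypothetical wiring sketch (assumption, not in the original code): encoder() reshapes its
# pooled representation into a [batch, num_context_patches, num_context_patches, ?] context map,
# which matches the convolutional input expected by task(). The placeholder shape (192x192 RGB)
# follows the "[192 x 192]" comment in encoder() and is an assumption, as is reusing
# representation_dim as the task output dimension.
def _encoder_task_wiring_sketch(b_train):
    images = tf.placeholder(tf.float32, [batch_size, 192, 192, 3], name='input_images')
    context = encoder(images, activation='relu', norm='layer', b_train=b_train)
    latent = task(context, activation='relu', output_dim=representation_dim, norm='layer', b_train=b_train)
    return latent
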