Example No. 1
    def channel_shuffle(cls, x, num_groups, name='_shuffle'):
        with tf.variable_scope(name):
            n, h, w, c = x.outputs.get_shape()
            x_reshaped = ReshapeLayer(x, (-1, h, w, num_groups, c // num_groups))  # split channels into num_groups groups
            x_transposed = TransposeLayer(x_reshaped, [0, 1, 2, 4, 3])  # swap the group and per-group channel axes
            output = ReshapeLayer(x_transposed, (-1, h, w, c))  # flatten back to the original channel layout
            return output
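For reference, the same reshape-transpose-reshape trick without the TensorLayer wrappers, as a minimal plain TensorFlow 1.x sketch (the function name and the assumption of static H, W, C are illustrative):

import tensorflow as tf

def channel_shuffle_tf(x, num_groups):
    # x: NHWC tensor; split channels into groups, swap the group and
    # per-group axes, then flatten back to the original channel layout.
    n, h, w, c = x.get_shape().as_list()
    x = tf.reshape(x, [-1, h, w, num_groups, c // num_groups])
    x = tf.transpose(x, [0, 1, 2, 4, 3])
    return tf.reshape(x, [-1, h, w, c])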
Example No. 2
    def __get_network__(self,
                        encode_seq,
                        decode_seq,
                        query_decode_seq,
                        is_train=True,
                        reuse=False):

        w_init = tf.random_normal_initializer(stddev=0.02)
        g_init = tf.random_normal_initializer(1., 0.02)

        with tf.variable_scope(self.model_name, reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            net_encode_traffic = InputLayer(encode_seq, name='in_root_net')
            net_encode_query = InputLayer(self.query_x, name="in_query_net")
            net_encode = ConcatLayer([net_encode_traffic, net_encode_query],
                                     concat_dim=-1,
                                     name="encode")

            net_decode_traffic = InputLayer(decode_seq, name="decode_root")
            net_decode_query = InputLayer(query_decode_seq,
                                          name="decode_query_net")
            net_decode = ConcatLayer([net_decode_traffic, net_decode_query],
                                     concat_dim=-1,
                                     name="decode")

            net_rnn = Seq2Seq(
                net_encode,
                net_decode,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                n_hidden=config.dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_encode.outputs),
                decode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_decode.outputs),
                initial_state_encode=None,
                # dropout=(0.8 if is_train else None),
                dropout=None,
                n_layer=1,
                return_seq_2d=True,
                name='seq2seq')
            # net_out = DenseLayer(net_rnn, n_units=64, act=tf.identity, name='dense1')
            net_out = DenseLayer(net_rnn,
                                 n_units=1,
                                 act=tf.identity,
                                 name='dense2')
            if is_train:
                net_out = ReshapeLayer(
                    net_out, (config.batch_size, config.out_seq_length + 1, 1),
                    name="reshape_out")
            else:
                net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1),
                                       name="reshape_out")

            self.net_rnn = net_rnn

            return net_out
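The encode and decode sequence lengths above are computed with tl.layers.retrieve_seq_length_op, which counts the non-padding time steps of a [batch, n_step, n_features] tensor. A rough plain-TensorFlow equivalent, shown only as a sketch (the library's exact zero-padding convention may differ):

def seq_length(data):
    # a step counts as "used" if any of its features is non-zero
    used = tf.sign(tf.reduce_max(tf.abs(data), axis=2))
    return tf.cast(tf.reduce_sum(used, axis=1), tf.int32)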
Example No. 3
def mobilenet(x, is_train=True, reuse=False):
    with tf.variable_scope("mobilenet", reuse=reuse):
        n = InputLayer(x)
        n = conv_block(n, 32, strides=(2, 2), is_train=is_train, name="conv")
        n = depthwise_conv_block(n, 64, is_train=is_train, name="depth1")

        n = depthwise_conv_block(n, 128, strides=(2, 2), is_train=is_train, name="depth2")
        n = depthwise_conv_block(n, 128, is_train=is_train, name="depth3")

        n = depthwise_conv_block(n, 256, strides=(2, 2), is_train=is_train, name="depth4")
        n = depthwise_conv_block(n, 256, is_train=is_train, name="depth5")

        n = depthwise_conv_block(n, 512, strides=(2, 2), is_train=is_train, name="depth6")
        n = depthwise_conv_block(n, 512, is_train=is_train, name="depth7")
        n = depthwise_conv_block(n, 512, is_train=is_train, name="depth8")
        n = depthwise_conv_block(n, 512, is_train=is_train, name="depth9")
        n = depthwise_conv_block(n, 512, is_train=is_train, name="depth10")
        n = depthwise_conv_block(n, 512, is_train=is_train, name="depth11")

        n = depthwise_conv_block(n, 1024, strides=(2, 2), is_train=is_train, name="depth12")
        n = depthwise_conv_block(n, 1024, is_train=is_train, name="depth13")

        n = GlobalMeanPool2d(n)
        # n = DropoutLayer(n, 1-1e-3, True, is_train, name='drop')
        # n = DenseLayer(n, 1000, act=tf.identity, name='output')   # equal
        n = ReshapeLayer(n, [-1, 1, 1, 1024])
        n = Conv2d(n, 1000, (1, 1), (1, 1), name='out')
        n = FlattenLayer(n)
    return n
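A minimal usage sketch for the network above, assuming the conv_block and depthwise_conv_block helpers defined elsewhere in this repository and the standard 224x224 RGB MobileNet input (the placeholder name is illustrative):

x = tf.placeholder(tf.float32, [None, 224, 224, 3], name="images")
net = mobilenet(x, is_train=False, reuse=False)
probs = tf.nn.softmax(net.outputs)  # 1000-way logits from the 1x1 'out' conv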
Example No. 4
def generator(inputs, is_train=True, reuse=False):
    image_size = 128

    gf_dim = 64    # Dimension of gen filters in first conv layer. [64]
    c_dim = 1      # n_color 1
    w_init = tf.glorot_normal_initializer()
    gamma_init = tf.random_normal_initializer(1., 0.02)

    with tf.name_scope("GENERATOR"):

        with tf.variable_scope("generator", reuse=reuse):

            with tf.name_scope("net_in"):
                net_in = InputLayer(inputs, name='g/in')
            #########################################################################
            with tf.name_scope("layer0"):
                net_h0 = DenseLayer(net_in, n_units=(gf_dim * 32 * 4 * 4), W_init=w_init,
                                    act=tf.identity, name='g/h0/lin')
                net_h0 = ReshapeLayer(net_h0, shape=[-1, 4, 4, gf_dim * 32], name='g/h0/reshape')
                net_h0 = BatchNormLayer(net_h0, decay=0.9, act=tf.nn.relu, is_train=is_train,
                                        gamma_init=gamma_init, name='g/h0/batch_norm')

            with tf.name_scope("layer1"):
                net_h1 = DeConv2d(net_h0, gf_dim * 8, (5, 5), strides=(2, 2),
                                  padding='SAME', act=None, W_init=w_init, name='g/h1/decon2d')
                net_h1 = BatchNormLayer(net_h1, decay=0.9, act=tf.nn.relu, is_train=is_train,
                                        gamma_init=gamma_init, name='g/h1/batch_norm')

            with tf.name_scope("layer2"):
                net_h2 = DeConv2d(net_h1, gf_dim * 4, (5, 5), strides=(2, 2),
                                  padding='SAME', act=None, W_init=w_init, name='g/h2/decon2d')
                net_h2 = BatchNormLayer(net_h2, decay=0.9, act=tf.nn.relu, is_train=is_train,
                                        gamma_init=gamma_init, name='g/h2/batch_norm')

            with tf.name_scope("layer3"):
                net_h3 = DeConv2d(net_h2, gf_dim * 2, (5, 5), strides=(2, 2),
                                  padding='SAME', act=None, W_init=w_init, name='g/h3/decon2d')
                net_h3 = BatchNormLayer(net_h3, decay=0.9, act=tf.nn.relu, is_train=is_train,
                                        gamma_init=gamma_init, name='g/h3/batch_norm')

            with tf.name_scope("layer4"):
                net_h4 = DeConv2d(net_h3, gf_dim, (5, 5), strides=(2, 2),
                                  padding='SAME', act=None, W_init=w_init, name='g/h4/decon2d')
                net_h4 = BatchNormLayer(net_h4, decay=0.9, act=tf.nn.relu, is_train=is_train,
                                        gamma_init=gamma_init, name='g/h4/batch_norm')

            with tf.name_scope("layer5"):
                net_h5 = DeConv2d(net_h4, c_dim, (5, 5), strides=(2, 2),
                                  padding='SAME', act=None, W_init=w_init, name='g/h5/decon2d')
                net_h5.outputs = tf.nn.tanh(net_h5.outputs)

        return net_h5
Example No. 5
    def mobilenetv1(self, x, end_with='out', is_train=False, reuse=None):
        with tf.variable_scope("mobilenetv1", reuse=reuse):
            n = InputLayer(x)
            n = self.conv_block(n, 32, strides=(2, 2), is_train=is_train, name="conv")
            if end_with in n.outputs.name: return n
            n = self.depthwise_conv_block(n, 64, is_train=is_train, name="depth1")
            if end_with in n.outputs.name: return n

            n = self.depthwise_conv_block(n, 128, strides=(2, 2), is_train=is_train, name="depth2")
            if end_with in n.outputs.name: return n
            n = self.depthwise_conv_block(n, 128, is_train=is_train, name="depth3")
            if end_with in n.outputs.name: return n

            n = self.depthwise_conv_block(n, 256, strides=(2, 2), is_train=is_train, name="depth4")
            if end_with in n.outputs.name: return n
            n = self.depthwise_conv_block(n, 256, is_train=is_train, name="depth5")
            if end_with in n.outputs.name: return n

            n = self.depthwise_conv_block(n, 512, strides=(2, 2), is_train=is_train, name="depth6")
            if end_with in n.outputs.name: return n
            n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth7")
            if end_with in n.outputs.name: return n
            n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth8")
            if end_with in n.outputs.name: return n
            n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth9")
            if end_with in n.outputs.name: return n
            n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth10")
            if end_with in n.outputs.name: return n
            n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth11")
            if end_with in n.outputs.name: return n

            n = self.depthwise_conv_block(n, 1024, strides=(2, 2), is_train=is_train, name="depth12")
            if end_with in n.outputs.name: return n
            n = self.depthwise_conv_block(n, 1024, is_train=is_train, name="depth13")
            if end_with in n.outputs.name: return n

            n = GlobalMeanPool2d(n, name='globalmeanpool')
            if end_with in n.outputs.name: return n
            # n = DropoutLayer(n, 1-1e-3, True, is_train, name='drop')
            # n = DenseLayer(n, 1000, act=tf.identity, name='output')   # equal
            n = ReshapeLayer(n, [-1, 1, 1, 1024], name='reshape')
            if end_with in n.outputs.name: return n
            n = Conv2d(n, 1000, (1, 1), (1, 1), name='out')
            n = FlattenLayer(n, name='flatten')
            if end_with == 'out': return n

            raise Exception("end_with : conv, depth1, depth2 ... depth13, globalmeanpool, out")
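The end_with argument truncates the network at the first layer whose tensor name contains the given string, which makes it easy to reuse the model as a feature extractor. A hedged usage sketch (the model instance and placeholder are hypothetical):

x = tf.placeholder(tf.float32, [None, 224, 224, 3])
feats = model.mobilenetv1(x, end_with='depth13', is_train=False)           # stop before the classifier head
logits = model.mobilenetv1(x, end_with='out', is_train=False, reuse=True)  # full network, shared weights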
Example No. 6
    def generator(self, z, label_class, is_train=True, reuse=False):
        # NOTE: concatenating z with the one-hot label here might be wrong; needs testing
        labels_one_hot = tf.one_hot(label_class, self.class_num)
        z_labels = tf.concat([z, labels_one_hot], 1)
        image_size = self.images_size
        s16 = image_size // 16
        gf_dim = 64    # Dimension of gen filters in first conv layer. [64]
        c_dim = self.channel    # n_color 3
        w_init = tf.glorot_normal_initializer()
        gamma_init = tf.random_normal_initializer(1., 0.02)

        with tf.variable_scope("generator", reuse=reuse):
            net_in = InputLayer(z_labels, name='g/in')
            net_h0 = DenseLayer(net_in, n_units=(gf_dim * 8 * s16 * s16), W_init=w_init,
                    act=tf.identity, name='g/h0/lin')
            net_h0 = ReshapeLayer(net_h0, shape=[-1, s16, s16, gf_dim*8], name='g/h0/reshape')
            net_h0 = BatchNormLayer(net_h0, decay=0.9, act=tf.nn.relu, is_train=is_train,
                    gamma_init=gamma_init, name='g/h0/batch_norm')

            net_h1 = DeConv2d(net_h0, gf_dim * 4, (5, 5), strides=(2, 2),
                    padding='SAME', act=None, W_init=w_init, name='g/h1/decon2d')
            net_h1 = BatchNormLayer(net_h1, decay=0.9, act=tf.nn.relu, is_train=is_train,
                    gamma_init=gamma_init, name='g/h1/batch_norm')

            net_h2 = DeConv2d(net_h1, gf_dim * 2, (5, 5), strides=(2, 2),
                    padding='SAME', act=None, W_init=w_init, name='g/h2/decon2d')
            net_h2 = BatchNormLayer(net_h2, decay=0.9, act=tf.nn.relu, is_train=is_train,
                    gamma_init=gamma_init, name='g/h2/batch_norm')

            net_h3 = DeConv2d(net_h2, gf_dim, (5, 5), strides=(2, 2),
                    padding='SAME', act=None, W_init=w_init, name='g/h3/decon2d')
            net_h3 = BatchNormLayer(net_h3, decay=0.9, act=tf.nn.relu, is_train=is_train,
                    gamma_init=gamma_init, name='g/h3/batch_norm')

            net_h4 = DeConv2d(net_h3, c_dim, (5, 5), strides=(2, 2),
                    padding='SAME', act=None, W_init=w_init, name='g/h4/decon2d')
            net_h4.outputs = tf.nn.tanh(net_h4.outputs)
        return net_h4
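A hypothetical call site for this class-conditional generator (the latent size 100 is an assumption; self.class_num and self.images_size come from the enclosing model object):

z = tf.placeholder(tf.float32, [None, 100], name="z")
label_class = tf.placeholder(tf.int32, [None], name="label_class")
net_g = model.generator(z, label_class, is_train=True, reuse=False)
fake_images = net_g.outputs  # in [-1, 1] because of the final tanh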
Example No. 7
def generator(input_placeholder, train_mode, image_size, reuse=False):
    s2, s4, s8, s16 = int(image_size / 2), int(image_size / 4), int(
        image_size / 8), int(image_size / 16)
    gf_dim = 32

    w_init = tf.random_normal_initializer(stddev=0.02)
    gamma_init = tf.random_normal_initializer(1., 0.02)

    with tf.variable_scope("decoder", reuse=reuse):
        tl.layers.set_name_reuse(reuse)

        input_layer = InputLayer(input_placeholder, name='dec/input')
        lin_layer = DenseLayer(input_layer,
                               n_units=gf_dim * 8 * s16 * s16,
                               W_init=w_init,
                               act=tf.identity,
                               name='dec/lin')
        # lin_layer.shape = (batch_size,256*4*4)
        resh1_layer = ReshapeLayer(lin_layer,
                                   shape=[-1, s16, s16, gf_dim * 8],
                                   name='decoder/reshape')
        # resh1_layer.shape = (batch_size, 4, 4, 256)
        in_bn_layer = BatchNormLayer(resh1_layer,
                                     act=lambda x: tl.act.lrelu(x, 0.2),
                                     is_train=train_mode,
                                     gamma_init=gamma_init,
                                     name='dec/in_bn')

        # upsampling
        up1_layer = UpSampling2dLayer(in_bn_layer,
                                      size=[s8, s8],
                                      is_scale=False,
                                      method=ResizeMethod.NEAREST_NEIGHBOR,
                                      align_corners=False,
                                      name='dec/up1')
        conv1_layer = Conv2d(up1_layer,
                             gf_dim * 4, (3, 3), (1, 1),
                             padding='SAME',
                             W_init=w_init,
                             name='dec/conv1')
        bn1_layer = BatchNormLayer(conv1_layer,
                                   act=lambda x: tl.act.lrelu(x, 0.2),
                                   is_train=train_mode,
                                   gamma_init=gamma_init,
                                   name='dec/bn1')
        # bn1_layer.shape = (batch_size,8,8,128)

        up2_layer = UpSampling2dLayer(bn1_layer,
                                      size=[s4, s4],
                                      is_scale=False,
                                      method=ResizeMethod.NEAREST_NEIGHBOR,
                                      align_corners=False,
                                      name='dec/up2')
        conv2_layer = Conv2d(up2_layer,
                             gf_dim * 2, (3, 3), (1, 1),
                             padding='SAME',
                             W_init=w_init,
                             name='dec/conv2')
        bn2_layer = BatchNormLayer(conv2_layer,
                                   act=lambda x: tl.act.lrelu(x, 0.2),
                                   is_train=train_mode,
                                   gamma_init=gamma_init,
                                   name='dec/bn2')
        # bn2_layer.shape = (batch_size,16,16,64)

        up3_layer = UpSampling2dLayer(bn2_layer,
                                      size=[s2, s2],
                                      is_scale=False,
                                      method=ResizeMethod.NEAREST_NEIGHBOR,
                                      align_corners=False,
                                      name='dec/up3')
        conv3_layer = Conv2d(up3_layer,
                             gf_dim, (3, 3), (1, 1),
                             padding='SAME',
                             W_init=w_init,
                             name='dec/conv3')
        bn3_layer = BatchNormLayer(conv3_layer,
                                   act=lambda x: tl.act.lrelu(x, 0.2),
                                   is_train=train_mode,
                                   gamma_init=gamma_init,
                                   name='dec/bn3_layer')
        # bn3_layer.shape = (batch_size,32,32,32)

        # no BN on last deconv
        up4_layer = UpSampling2dLayer(bn3_layer,
                                      size=[image_size, image_size],
                                      is_scale=False,
                                      method=ResizeMethod.NEAREST_NEIGHBOR,
                                      align_corners=False,
                                      name='dec/up4')
        conv4_layer = Conv2d(up4_layer,
                             3, (3, 3), (1, 1),
                             padding='SAME',
                             W_init=w_init,
                             name='dec/conv4')
        # conv4_layer.shape = (batch_size,64,64,3)
        logits = conv4_layer.outputs
        conv4_layer.outputs = tf.nn.tanh(conv4_layer.outputs)
    return conv4_layer, logits
Example No. 8
def generator(inputs, is_train=True):
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        net_in = InputLayer(inputs, name='gin')

        gnet_d0 = DenseLayer(net_in,
                             n_units=(16384),
                             act=tf.identity,
                             name='gnet_d0')
        gnet_r0 = ReshapeLayer(gnet_d0, shape=[-1, 4, 4, 1024], name='gnet_r0')
        gnet_b0 = BatchNormLayer(gnet_r0,
                                 decay=0.9,
                                 act=tf.nn.relu,
                                 is_train=is_train,
                                 name='gnet_b0')

        gnet_dc1 = DeConv2d(gnet_b0,
                            256, (8, 8),
                            strides=(2, 2),
                            padding='SAME',
                            act=None,
                            name='gnet_dc1')
        gnet_b1 = BatchNormLayer(gnet_dc1,
                                 decay=0.9,
                                 act=tf.nn.relu,
                                 is_train=is_train,
                                 name='gnet_b1')

        gnet_dc2 = DeConv2d(gnet_b1,
                            128, (8, 8),
                            strides=(2, 2),
                            padding='SAME',
                            act=None,
                            name='gnet_dc2')
        gnet_b2 = BatchNormLayer(gnet_dc2,
                                 decay=0.9,
                                 act=tf.nn.relu,
                                 is_train=is_train,
                                 name='gnet_b2')

        gnet_dc3 = DeConv2d(gnet_b2,
                            64, (8, 8),
                            strides=(2, 2),
                            padding='SAME',
                            act=None,
                            name='gnet_dc3')
        gnet_b3 = BatchNormLayer(gnet_dc3,
                                 decay=0.9,
                                 act=tf.nn.relu,
                                 is_train=is_train,
                                 name='gnet_b3')

        gnet_dc4 = DeConv2d(gnet_b3,
                            3, (8, 8),
                            strides=(2, 2),
                            padding='SAME',
                            act=None,
                            name='net_h4')

        # Per the paper, the final tanh adds a non-linearity that bounds the
        # generated image to [-1, 1], matching inputs normalized to that range.
        gnet_dc4.outputs = tf.nn.tanh(gnet_dc4.outputs)
    return gnet_dc4
Example No. 9
    def __get_network__(self, model_name, encode_seqs, class_label_seqs, kg_vector, reuse=False, is_train=True):
        with tf.variable_scope(model_name, reuse=reuse):
            tl.layers.set_name_reuse(reuse)

            net_word_embed = InputLayer(
                inputs=encode_seqs,
                name="in_word_embed"
            )

            net_class_label_embed = InputLayer(
                inputs=class_label_seqs,
                name="in_class_label_embed"
            )

            net_kg = InputLayer(
                inputs=kg_vector,
                name='in_kg'
            )

            net_kg = ReshapeLayer(
                net_kg,
                shape=(-1, self.kg_embedding_dim),
                name="reshape_kg_1"
            )

            net_kg = ReshapeLayer(
                net_kg,
                shape=(-1, self.max_length, self.kg_embedding_dim),
                name="reshape_kg_2"
            )

            if config.model == "vwvcvkg":
                # dbpedia and 20news
                net_in = ConcatLayer(
                    [net_word_embed, net_class_label_embed, net_kg],
                    concat_dim=-1,
                    name='concat_vw_vwc_vc'
                )
            elif config.model == "vwvc":
                net_in = ConcatLayer(
                    [net_word_embed, net_class_label_embed],
                    concat_dim=-1,
                    name='concat_vw_vc'
                )
            elif config.model == "vwvkg":
                net_in = ConcatLayer(
                    [net_word_embed, net_kg],
                    concat_dim=-1,
                    name='concat_vw_vwc'
                )
            elif config.model == "vcvkg":
                net_in = ConcatLayer(
                    [net_class_label_embed, net_kg],
                    concat_dim=-1,
                    name='concat_vc_vwc'
                )
            elif config.model == "kgonly":
                net_in = ConcatLayer(
                    [net_kg],
                    concat_dim=-1,
                    name='concat_vwc'
                )
            else:
                raise Exception("config.model value error")

            filter_length = [2, 4, 8]
            # dbpedia
            n_filter = 600
            # n_filter = 200

            net_cnn_list = list()

            for fsz in filter_length:

                net_cnn = Conv1d(
                    net_in,
                    n_filter=n_filter,
                    filter_size=fsz,
                    stride=1,
                    act=tf.nn.relu,
                    name="cnn%d" % fsz
                )
                net_cnn.outputs = tf.reduce_max(net_cnn.outputs, axis=1, name="global_maxpool%d" % fsz)
                net_cnn_list.append(net_cnn)

            '''
            if config.model == "vwvcvkg":
                net_class_label_embed.outputs = tf.slice(
                    net_class_label_embed.outputs,
                    [0, 0, 0],
                    [config.batch_size, 1, self.word_embedding_dim],
                    name="slice_word"
                )
                net_class_label_embed.outputs = tf.squeeze(
                    net_class_label_embed.outputs,
                    name="squeeze_word"
                )
                net_cnn = ConcatLayer(net_cnn_list + [net_class_label_embed], concat_dim=-1)
            else:
                net_cnn = ConcatLayer(net_cnn_list, concat_dim=-1)
            '''
            net_cnn = ConcatLayer(net_cnn_list, concat_dim=-1)

            net_fc = DropoutLayer(net_cnn, keep=0.5, is_fix=True, is_train=is_train, name='drop1')

            net_fc = DenseLayer(
                net_fc,
                n_units=400,
                act=tf.nn.relu,
                name="fc_1"
            )

            net_fc = DropoutLayer(net_fc, keep=0.5, is_fix=True, is_train=is_train, name='drop2')

            # dbpedia
            net_fc = DenseLayer(
                net_fc,
                n_units=100,
                act=tf.nn.relu,
                name="fc_2"
            )
            net_fc = DropoutLayer(net_fc, keep=0.5, is_fix=True, is_train=is_train, name='drop3')

            net_fc = DenseLayer(
                net_fc,
                n_units=1,
                act=tf.nn.sigmoid,
                name="fc_3"
            )
        return net_fc, net_cnn
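The loop above is the usual TextCNN pattern: several Conv1d branches with different filter widths, each followed by a global max pool over time (the tf.reduce_max on axis 1), then concatenation of the pooled vectors. A minimal plain-TensorFlow sketch of that pattern, assuming an embedded input tensor of shape [batch, max_length, embed_dim]:

pooled = []
for fsz in [2, 4, 8]:
    conv = tf.layers.conv1d(embedded, filters=600, kernel_size=fsz,
                            padding='same', activation=tf.nn.relu)
    pooled.append(tf.reduce_max(conv, axis=1))  # global max pool over time
features = tf.concat(pooled, axis=-1)           # [batch, 3 * 600]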
Example No. 10
def generator(input_placeholder,
              train_mode,
              image_size,
              batch_size,
              reuse=False,
              filters_num=128):

    w_init = tf.random_normal_initializer(stddev=0.02)
    gamma_init = tf.random_normal_initializer(1., 0.02)

    s2, s4, s8, s16 = int(image_size / 2), int(image_size / 4), int(
        image_size / 8), int(image_size / 16)

    with tf.variable_scope("generator", reuse=reuse):
        tl.layers.set_name_reuse(reuse)

        input_layer = InputLayer(input_placeholder, name='gen/in')
        lin_layer = DenseLayer(input_layer,
                               n_units=filters_num * 8 * s16 * s16,
                               W_init=w_init,
                               act=tf.identity,
                               name='gen/lin')

        resh1_layer = ReshapeLayer(lin_layer,
                                   shape=[-1, s16, s16, filters_num * 8],
                                   name='gen/reshape')

        in_bn_layer = BatchNormLayer(resh1_layer,
                                     act=tf.nn.relu,
                                     is_train=train_mode,
                                     gamma_init=gamma_init,
                                     name='dec/in_bn')
        # in_bn_layer.shape = (batch_size, 4, 4, 1024)
        up1_layer = DeConv2d(in_bn_layer,
                             filters_num * 4, (5, 5),
                             out_size=(s8, s8),
                             strides=(2, 2),
                             padding='SAME',
                             batch_size=batch_size,
                             act=None,
                             W_init=w_init,
                             name='gen/up1')

        bn1_layer = BatchNormLayer(up1_layer,
                                   act=tf.nn.relu,
                                   is_train=train_mode,
                                   gamma_init=gamma_init,
                                   name='dec/bn1')

        # bn1_layer.shape = (batch_size, 8, 8, 512)
        up2_layer = DeConv2d(bn1_layer,
                             filters_num * 2, (5, 5),
                             out_size=(s4, s4),
                             strides=(2, 2),
                             padding='SAME',
                             batch_size=batch_size,
                             act=None,
                             W_init=w_init,
                             name='gen/up2')
        bn2_layer = BatchNormLayer(up2_layer,
                                   act=tf.nn.relu,
                                   is_train=train_mode,
                                   gamma_init=gamma_init,
                                   name='dec/bn2')
        # bn2_layer.shape = (batch_size, 16, 16, 256)

        up3_layer = DeConv2d(bn2_layer,
                             filters_num, (5, 5),
                             out_size=(s2, s2),
                             strides=(2, 2),
                             padding='SAME',
                             batch_size=batch_size,
                             act=None,
                             W_init=w_init,
                             name='gen/up3')
        bn3_layer = BatchNormLayer(up3_layer,
                                   act=tf.nn.relu,
                                   is_train=train_mode,
                                   gamma_init=gamma_init,
                                   name='dec/bn3')
        # bn3_layer.shape = (batch_size, 32, 32, 128)
        up4_layer = DeConv2d(bn3_layer,
                             3, (5, 5),
                             out_size=(image_size, image_size),
                             strides=(2, 2),
                             padding='SAME',
                             batch_size=batch_size,
                             act=None,
                             W_init=w_init,
                             name='gen/up4')

        up4_layer.outputs = tf.nn.tanh(up4_layer.outputs)

    return up4_layer, up4_layer.outputs
Example No. 11
def generator(inputs, is_train=True, reuse=False):
    image_size = 64
    s16 = image_size // 16
    gf_dim = 64  # Dimension of gen filters in first conv layer. [64]
    c_dim = FLAGS.c_dim  # n_color 3
    w_init = tf.glorot_normal_initializer()
    gamma_init = tf.random_normal_initializer(1., 0.02)

    with tf.variable_scope("generator", reuse=reuse):

        net_in = InputLayer(inputs, name='g/in')
        net_h0 = DenseLayer(net_in,
                            n_units=(gf_dim * 8 * s16 * s16),
                            W_init=w_init,
                            act=tf.identity,
                            name='g/h0/lin')
        net_h0 = ReshapeLayer(net_h0,
                              shape=[-1, s16, s16, gf_dim * 8],
                              name='g/h0/reshape')
        net_h0 = BatchNormLayer(net_h0,
                                act=tf.nn.relu,
                                is_train=is_train,
                                gamma_init=gamma_init,
                                name='g/h0/batch_norm')

        net_h1 = DeConv2d(net_h0,
                          gf_dim * 4, (5, 5),
                          strides=(2, 2),
                          padding='SAME',
                          act=None,
                          W_init=w_init,
                          name='g/h1/decon2d')
        net_h1 = BatchNormLayer(net_h1,
                                act=tf.nn.relu,
                                is_train=is_train,
                                gamma_init=gamma_init,
                                name='g/h1/batch_norm')

        net_h2 = DeConv2d(net_h1,
                          gf_dim * 2, (5, 5),
                          strides=(2, 2),
                          padding='SAME',
                          act=None,
                          W_init=w_init,
                          name='g/h2/decon2d')
        net_h2 = BatchNormLayer(net_h2,
                                act=tf.nn.relu,
                                is_train=is_train,
                                gamma_init=gamma_init,
                                name='g/h2/batch_norm')

        net_h3 = DeConv2d(net_h2,
                          gf_dim, (5, 5),
                          strides=(2, 2),
                          padding='SAME',
                          act=None,
                          W_init=w_init,
                          name='g/h3/decon2d')
        net_h3 = BatchNormLayer(net_h3,
                                act=tf.nn.relu,
                                is_train=is_train,
                                gamma_init=gamma_init,
                                name='g/h3/batch_norm')

        net_h4 = DeConv2d(net_h3,
                          c_dim, (5, 5),
                          strides=(2, 2),
                          padding='SAME',
                          act=None,
                          W_init=w_init,
                          name='g/h4/decon2d')
        net_h4.outputs = tf.nn.tanh(net_h4.outputs)
    return net_h4
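A common way to build this generator twice, once for training and once for sampling, sharing weights through reuse (a sketch; the latent dimension 100 is an assumption):

z = tf.placeholder(tf.float32, [None, 100], name="z")
net_g_train = generator(z, is_train=True, reuse=False)
net_g_sample = generator(z, is_train=False, reuse=True)  # shares weights with the training graph
fake_images = net_g_sample.outputs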
Example No. 12
    def __get_network__(self,
                        encode_seq,
                        neighbour_seq,
                        decode_seq,
                        features,
                        features_full,
                        is_train=True,
                        reuse=False):
        w_init = tf.random_normal_initializer(stddev=0.02)
        g_init = tf.random_normal_initializer(1., 0.02)

        with tf.variable_scope(self.model_name + "_spatial",
                               reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            inputs_x_root = InputLayer(encode_seq, name='in_root')
            inputs_x_nbor = InputLayer(neighbour_seq, name="in_neighbour")

            # encoding neighbour graph information
            n = ReshapeLayer(inputs_x_nbor,
                             (config.batch_size * config.in_seq_length,
                              config.num_neighbour), "reshape1")
            n.outputs = tf.expand_dims(n.outputs, axis=-1)
            n = Conv1d(n,
                       4,
                       4,
                       1,
                       act=tf.identity,
                       padding='SAME',
                       W_init=w_init,
                       name='conv1')
            n = BatchNormLayer(n,
                               act=tf.nn.relu,
                               is_train=is_train,
                               gamma_init=g_init,
                               name='bn1')
            n = MaxPool1d(n, 2, 2, padding='valid', name='maxpool1')
            n = FlattenLayer(n, name="flatten1")
            n = ReshapeLayer(n, (config.batch_size, config.in_seq_length, -1),
                             name="reshape1_back")

            net_encode = ConcatLayer([inputs_x_root, n],
                                     concat_dim=-1,
                                     name="encode")
            net_decode = InputLayer(decode_seq, name="decode")

            net_rnn = Seq2Seq(
                net_encode,
                net_decode,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                n_hidden=config.dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_encode.outputs),
                decode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_decode.outputs),
                initial_state_encode=None,
                # dropout=(0.8 if is_train else None),
                dropout=None,
                n_layer=1,
                return_seq_2d=True,
                name='seq2seq')
            net_rnn_seq2seq = net_rnn

            net_spatial_out = DenseLayer(net_rnn,
                                         n_units=1,
                                         act=tf.identity,
                                         name='dense2')
            if is_train:
                net_spatial_out = ReshapeLayer(
                    net_spatial_out,
                    (config.batch_size, config.out_seq_length + 1, 1),
                    name="reshape_out")
            else:
                net_spatial_out = ReshapeLayer(net_spatial_out,
                                               (config.batch_size, 1, 1),
                                               name="reshape_out")

        with tf.variable_scope(self.model_name + "_wide", reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            # Features
            net_features = InputLayer(features, name="in_features")
            net_features_full = InputLayer(features_full,
                                           name="in_features_full")
            net_features_full = ReshapeLayer(
                net_features_full,
                (config.batch_size *
                 (config.out_seq_length + 1), config.dim_features),
                name="reshape_feature_full_1")
            if is_train:
                net_features = ReshapeLayer(
                    net_features,
                    (config.batch_size *
                     (config.out_seq_length + 1), config.dim_features),
                    name="reshape_feature_1")
            else:
                net_features = ReshapeLayer(net_features,
                                            (config.batch_size *
                                             (1), config.dim_features),
                                            name="reshape_feature_1")

            self.net_features_dim = 32
            net_features = DenseLayer(net_features,
                                      n_units=self.net_features_dim,
                                      act=tf.nn.relu,
                                      name='dense_features')
            net_features_full = DenseLayer(net_features_full,
                                           n_units=self.net_features_dim,
                                           act=tf.nn.relu,
                                           name='dense_features_full')
            # self.net_features = net_features

            net_wide_out = ConcatLayer([net_rnn_seq2seq, net_features],
                                       concat_dim=-1,
                                       name="concat_features")
            net_wide_out = DenseLayer(net_wide_out,
                                      n_units=1,
                                      act=tf.identity,
                                      name='dense2')

            if is_train:
                net_wide_out = ReshapeLayer(
                    net_wide_out,
                    (config.batch_size, config.out_seq_length + 1, 1),
                    name="reshape_out")
            else:
                net_wide_out = ReshapeLayer(net_wide_out,
                                            (config.batch_size, 1, 1),
                                            name="reshape_out")

        with tf.variable_scope(self.model_name + "_query", reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)

            net_decode_query = InputLayer(self.query_decode_seq,
                                          name="decode_query")

            net_rnn_query = RNNLayer(
                net_decode_query,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                cell_init_args={"forget_bias": 1.0},
                n_hidden=config.query_dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                n_steps=config.out_seq_length,
                return_last=True,

                # return_last=False,
                # return_seq_2d=True,
                name="rnn_query")
            '''
            net_rnn_query = DynamicRNNLayer(
                net_decode_query,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                cell_init_args={"forget_bias": 1.0},
                # n_hidden=config.query_dim_hidden,
                n_hidden=32,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                return_last=True,
                # dropout=0.8,
                sequence_length=tl.layers.retrieve_seq_length_op(net_decode_query.outputs),
                # return_last=False,
                # return_seq_2d=True,
                name="rnn_query_dynamic"
            )
            '''

            net_rnn_query = ExpandDimsLayer(net_rnn_query,
                                            axis=1,
                                            name="rnn_query_expand")
            net_rnn_query = TileLayer(net_rnn_query,
                                      [1, config.out_seq_length, 1],
                                      name="rnn_query_tile")
            net_rnn_query = ReshapeLayer(
                net_rnn_query, (config.batch_size * config.out_seq_length,
                                config.query_dim_hidden),
                name="rnn_query_reshape")
            # net_rnn_query = ReshapeLayer(net_rnn_query, (config.batch_size * config.out_seq_length, 32), name="rnn_query_reshape")

            # self.net_rnn_query = net_rnn_query

            net_traffic_state = InputLayer(self.traffic_state,
                                           name="in_traffic_state")
            '''
            if is_train:
                net_rnn_traffic = ReshapeLayer(net_rnn_seq2seq, (config.batch_size, config.out_seq_length + 1, config.dim_hidden), name="reshape_traffic_q1")
                net_rnn_traffic.outputs = tf.slice(net_rnn_traffic.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, config.dim_hidden], name="slice_traffic_q")
                net_rnn_traffic = ReshapeLayer(net_rnn_traffic, (config.batch_size * config.out_seq_length, config.dim_hidden), name="reshape_traffic_q2")

                net_features_traffic = ReshapeLayer(net_features, (config.batch_size, config.out_seq_length + 1, self.net_features_dim), name="reshape_features_q1")
                net_features_traffic.outputs = tf.slice(net_features_traffic.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, self.net_features_dim], name="slice_features_q")
                net_features_traffic = ReshapeLayer(net_features_traffic, (config.batch_size * config.out_seq_length, self.net_features_dim), name="reshape_features_q2")

                net_query_out = ConcatLayer([net_rnn_traffic, net_features_traffic, net_rnn_query], concat_dim=-1, name="concat_traffic_query1")
                # net_query_out = ConcatLayer([net_rnn_traffic, net_rnn_query], concat_dim=-1, name="concat_traffic_query1")
            else:
            '''
            net_features_traffic = ReshapeLayer(
                net_features_full,
                (config.batch_size, config.out_seq_length + 1,
                 self.net_features_dim),
                name="reshape_features_q1")
            net_features_traffic.outputs = tf.slice(
                net_features_traffic.outputs, [0, 0, 0], [
                    config.batch_size, config.out_seq_length,
                    self.net_features_dim
                ],
                name="slice_features_q")
            net_features_traffic = ReshapeLayer(
                net_features_traffic,
                (config.batch_size * config.out_seq_length,
                 self.net_features_dim),
                name="reshape_features_q2")

            net_query_out = ConcatLayer(
                [net_traffic_state, net_features_traffic, net_rnn_query],
                concat_dim=-1,
                name="concat_traffic_query1")
            # net_rnn_traffic = ReshapeLayer(net_rnn_seq2seq, (config.batch_size, config.out_seq_length + 1, config.dim_hidden), name="reshape_traffic_q1")
            # net_rnn_traffic.outputs = tf.slice(net_rnn_traffic.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, config.dim_hidden], name="slice_traffic_q")
            # net_rnn_traffic = ReshapeLayer(net_rnn_traffic, (config.batch_size * config.out_seq_length, config.dim_hidden), name="reshape_traffic_q2")
            # net_query_out = ConcatLayer([net_rnn_traffic, net_features_traffic, net_rnn_query], concat_dim=-1, name="concat_traffic_query1")

            # net_out = DenseLayer(net_out, n_units=128, act=tf.nn.relu, name="dense_query1")
            # net_out = DenseLayer(net_out, n_units=64, act=tf.nn.relu, name="dense_query2")
            # net_query_out = DropoutLayer(net_query_out, keep=0.8, is_fix=True, is_train=is_train, name='drop_query3')
            net_query_out = DenseLayer(net_query_out,
                                       n_units=1,
                                       act=tf.identity,
                                       name="dense_query3")
            # net_out = ReshapeLayer(net_out, (config.batch_size, config.out_seq_length + 1, 1), name="reshape_out")
            # if is_train:
            net_query_out = ReshapeLayer(
                net_query_out, (config.batch_size, config.out_seq_length, 1),
                name="reshape_out")
            # else:
            #    net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1), name="reshape_out")

            # TODO residual net
            '''
            if is_train:
                net_query_out.outputs = tf.add(
                    net_query_out.outputs,
                    tf.slice(net_wide_out.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, 1]),
                    name="res_add"
                )
            else:
            '''
            net_base_pred = InputLayer(self.base_pred, name="in_net_base_pred")
            net_query_out.outputs = tf.add(net_query_out.outputs,
                                           net_base_pred.outputs,
                                           name="res_add")

        return net_rnn_seq2seq, net_spatial_out, net_wide_out, net_rnn_query, net_query_out
Example No. 13
    def __get_network__(self,
                        encode_seq,
                        decode_seq,
                        is_train=True,
                        reuse=False):
        w_init = tf.random_normal_initializer(stddev=0.02)
        g_init = tf.random_normal_initializer(1., 0.02)

        with tf.variable_scope("seq2seq_model", reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            net_encode = InputLayer(encode_seq, name='in_root')

            net_decode = InputLayer(decode_seq, name="decode")

            net_rnn = Seq2Seq(
                net_encode,
                net_decode,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                n_hidden=config.dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_encode.outputs),
                decode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_decode.outputs),
                initial_state_encode=None,
                # dropout=(0.8 if is_train else None),
                dropout=None,
                n_layer=1,
                return_seq_2d=True,
                name='seq2seq')
            # self.net_rnn_seq2seq = net_rnn
            net_rnn_seq2seq = net_rnn

            net_out_seq2seq = DenseLayer(net_rnn,
                                         n_units=1,
                                         act=tf.identity,
                                         name='dense2')
            if is_train:
                net_out_seq2seq = ReshapeLayer(
                    net_out_seq2seq,
                    (config.batch_size, config.out_seq_length + 1, 1),
                    name="reshape_out")
            else:
                net_out_seq2seq = ReshapeLayer(net_out_seq2seq,
                                               (config.batch_size, 1, 1),
                                               name="reshape_out")

            # net_out_seq2seq = net_out_seq2seq
            # net_out = DenseLayer(net_rnn, n_units=64, act=tf.identity, name='dense1')
            # net_out = DenseLayer(net_rnn, n_units=1, act=tf.identity, name='dense2')
            # net_out = ReshapeLayer(net_out, (config.batch_size, config.out_seq_length + 1, 1), name="reshape_out")

        with tf.variable_scope(self.model_name, reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            net_encode_query = InputLayer(self.query_x, name='in_root_query')

            net_decode_query = InputLayer(self.query_decode_seq,
                                          name="decode_query")

            net_rnn_query = RNNLayer(
                net_decode_query,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                cell_init_args={"forget_bias": 1.0},
                n_hidden=config.query_dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                n_steps=config.out_seq_length,
                return_last=True,
                # return_last=False,
                # return_seq_2d=True,
                name="rnn_query")
            net_rnn_query = ExpandDimsLayer(net_rnn_query,
                                            axis=1,
                                            name="rnn_query_expand")
            net_rnn_query = TileLayer(net_rnn_query,
                                      [1, config.out_seq_length, 1],
                                      name="rnn_query_tile")
            net_rnn_query = ReshapeLayer(
                net_rnn_query, (config.batch_size * config.out_seq_length,
                                config.query_dim_hidden),
                name="rnn_query_reshape")

            net_traffic_state = InputLayer(self.traffic_state,
                                           name="in_traffic_state")

            if is_train:
                net_rnn_traffic = ReshapeLayer(
                    net_rnn_seq2seq,
                    (config.batch_size, config.out_seq_length + 1,
                     config.dim_hidden),
                    name="reshape_traffic_q1")
                net_rnn_traffic.outputs = tf.slice(
                    net_rnn_traffic.outputs, [0, 0, 0], [
                        config.batch_size, config.out_seq_length,
                        config.dim_hidden
                    ],
                    name="slice_traffic_q")
                net_rnn_traffic = ReshapeLayer(
                    net_rnn_traffic,
                    (config.batch_size * config.out_seq_length,
                     config.dim_hidden),
                    name="reshape_traffic_q2")
                net_out = ConcatLayer([net_rnn_traffic, net_rnn_query],
                                      concat_dim=-1,
                                      name="concat_traffic_query1")
            else:
                net_out = ConcatLayer([net_traffic_state, net_rnn_query],
                                      concat_dim=-1,
                                      name="concat_traffic_query2")

            # net_out = DenseLayer(net_out, n_units=128, act=tf.nn.relu, name="dense_query1")
            # net_out = DenseLayer(net_out, n_units=32, act=tf.nn.relu, name="dense_query2")
            net_out = DenseLayer(net_out,
                                 n_units=1,
                                 act=tf.identity,
                                 name="dense_query3")
            # net_out = ReshapeLayer(net_out, (config.batch_size, config.out_seq_length + 1, 1), name="reshape_out")
            # if is_train:
            net_out = ReshapeLayer(
                net_out, (config.batch_size, config.out_seq_length, 1),
                name="reshape_out")
            # else:
            #    net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1), name="reshape_out")
        return net_rnn_seq2seq, net_out_seq2seq, net_rnn_query, net_out
Example No. 14
    def __get_network__(self,
                        encode_seq,
                        neighbour_seq,
                        decode_seq,
                        is_train=True,
                        reuse=False):
        w_init = tf.random_normal_initializer(stddev=0.02)
        g_init = tf.random_normal_initializer(1., 0.02)

        with tf.variable_scope(self.model_name, reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            inputs_x_root = InputLayer(encode_seq, name='in_root')
            inputs_x_nbor = InputLayer(neighbour_seq, name="in_neighbour")

            # encoding neighbour graph information
            n = ReshapeLayer(inputs_x_nbor,
                             (config.batch_size * config.in_seq_length,
                              config.num_neighbour), "reshape1")
            n.outputs = tf.expand_dims(n.outputs, axis=-1)
            n = Conv1d(n,
                       4,
                       4,
                       1,
                       act=tf.identity,
                       padding='SAME',
                       W_init=w_init,
                       name='conv1')
            n = BatchNormLayer(n,
                               act=tf.nn.relu,
                               is_train=is_train,
                               gamma_init=g_init,
                               name='bn1')
            n = MaxPool1d(n, 2, 2, padding='valid', name='maxpool1')
            n = FlattenLayer(n, name="flatten1")
            n = ReshapeLayer(n, (config.batch_size, config.in_seq_length, -1),
                             name="reshape1_back")

            net_encode = ConcatLayer([inputs_x_root, n],
                                     concat_dim=-1,
                                     name="encode")
            net_decode = InputLayer(decode_seq, name="decode")

            net_rnn = Seq2Seq(
                net_encode,
                net_decode,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                n_hidden=config.dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_encode.outputs),
                decode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_decode.outputs),
                initial_state_encode=None,
                # dropout=(0.8 if is_train else None),
                dropout=None,
                n_layer=1,
                return_seq_2d=True,
                name='seq2seq')
            # net_out = DenseLayer(net_rnn, n_units=64, act=tf.identity, name='dense1')
            net_out = DenseLayer(net_rnn,
                                 n_units=1,
                                 act=tf.identity,
                                 name='dense2')
            if is_train:
                net_out = ReshapeLayer(
                    net_out, (config.batch_size, config.out_seq_length + 1, 1),
                    name="reshape_out")
            else:
                net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1),
                                       name="reshape_out")

            self.net_rnn = net_rnn

            return net_out
Example No. 15
def generator(inputs, is_train=True, reuse=False):
    img_size = CFG.img_size
    s2, s4, s8, s16 = [int(img_size / i) for i in [2, 4, 8, 16]]
    gfs = 64
    channels = CFG.channels
    batch_size = CFG.batch_size

    W_init = tf.random_normal_initializer(stddev=0.02)
    gamma_init = tf.random_normal_initializer(1., 0.02)

    with tf.variable_scope('generator', reuse=reuse):
        tl.layers.set_name_reuse(reuse)

        g = InputLayer(inputs, name='g/inputs')
        g = DenseLayer(g,
                       gfs * 8 * s16 * s16,
                       W_init=W_init,
                       act=tl.act.identity,
                       name='g/fc1')
        g = ReshapeLayer(g, shape=(-1, s16, s16, gfs * 8), name='g/reshape2')
        g = BatchNormLayer(g,
                           act=tf.nn.relu,
                           is_train=is_train,
                           gamma_init=gamma_init,
                           name='g/bn3')

        g = DeConv2d(g,
                     gfs * 4, (5, 5),
                     out_size=(s8, s8),
                     strides=(2, 2),
                     batch_size=batch_size,
                     act=None,
                     W_init=W_init,
                     name='g/dconv4')
        g = BatchNormLayer(g,
                           act=tf.nn.relu,
                           is_train=is_train,
                           gamma_init=gamma_init,
                           name='g/bn5')

        g = DeConv2d(g,
                     gfs * 2, (5, 5),
                     out_size=(s4, s4),
                     strides=(2, 2),
                     batch_size=batch_size,
                     act=None,
                     W_init=W_init,
                     name='g/dconv6')
        g = BatchNormLayer(g,
                           act=tf.nn.relu,
                           is_train=is_train,
                           gamma_init=gamma_init,
                           name='g/bn7')

        g = DeConv2d(g,
                     gfs, (5, 5),
                     out_size=(s2, s2),
                     strides=(2, 2),
                     batch_size=batch_size,
                     act=None,
                     W_init=W_init,
                     name='g/dconv8')
        g = BatchNormLayer(g,
                           act=tf.nn.relu,
                           is_train=is_train,
                           gamma_init=gamma_init,
                           name='g/bn9')

        g = DeConv2d(g,
                     channels, (5, 5),
                     out_size=(img_size, img_size),
                     strides=(2, 2),
                     batch_size=batch_size,
                     act=None,
                     W_init=W_init,
                     name='g/dconv10')

        logits = g.outputs
        g.outputs = tf.nn.tanh(g.outputs)
    return g, logits