def discriminator_net(self, lx, reuse=False):
        """Build the discriminator graph and return one raw score per image.

        Architecture:
        1. Four convolution layers with leaky-ReLU activations.
           Layers 2-4 are also batch-normalized.
        2. The final feature map is flattened.
        3. A fully connected layer maps the flattened features to a single
           output unit.

        :param lx: batch of images coming from the lens network. Presumably
            shaped (batch, height, width, channels) with (height, width)
            matching ``self.input_size`` — TODO confirm against the caller.
        :param reuse: if True, reuse the variables previously created under
            the 'discriminator' variable scope instead of creating new ones.
        :return: output of ``ops.full_connect`` with ``output_num=1``.
            No sigmoid is applied in this method.
        """
        # Spatial size of the feature map after the four conv layers.
        # Each layer is assumed to halve height and width via
        # utl.get_out_size(..., 2). NOTE(review): this relies on ops.conv2d
        # downsampling with stride 2 by default — confirm in ops.
        s_h, s_w, _ = self.input_size
        for _ in range(4):
            s_h, s_w = utl.get_out_size(s_h, 2), utl.get_out_size(s_w, 2)

        with tf.variable_scope('discriminator', reuse=reuse):
            h0 = ops.conv2d(lx, output_num=self.size, name='d_h0', reuse=reuse)
            h0 = ops.lrelu(h0, name='d_l0')

            h1 = ops.conv2d(h0, output_num=self.size * 2, name='d_h1', reuse=reuse)
            h1 = ops.batch_normalizer(h1, name='d_bn1', reuse=reuse)
            h1 = ops.lrelu(h1, name='d_l1')

            h2 = ops.conv2d(h1, output_num=self.size * 4, name='d_h2', reuse=reuse)
            h2 = ops.batch_normalizer(h2, name='d_bn2', reuse=reuse)
            h2 = ops.lrelu(h2, name='d_l2')

            h3 = ops.conv2d(h2, output_num=self.size * 8, name='d_h3', reuse=reuse)
            h3 = ops.batch_normalizer(h3, name='d_bn3', reuse=reuse)
            h3 = ops.lrelu(h3, name='d_l3')

            # Flatten to (batch, features). The batch dimension is -1 so that
            # tf.reshape infers it; any batch size can then be scored.
            # Previously it was hard-coded to self.batch_size. When lx has a
            # static batch, TF still infers that value statically.
            flat_dim = s_h * s_w * self.size * 8
            h4 = tf.reshape(h3, [-1, flat_dim])

            h4 = ops.full_connect(h4, output_num=1, name='d_full', reuse=reuse)
            return h4
Esempio n. 2
0
    def lens_net(self, x, reuse=False):
        """Build the lens network: a residual image-to-image transform of x.

        Architecture:
        1. A 3x3, stride-1 convolution, followed by leaky ReLU.
        2. Two ``ops.res_block3_3`` residual blocks.
        3. A 3x3, stride-1 convolution down to 3 channels, followed by
           leaky ReLU with leak 0.4.
        4. The result is added to ``x`` as a global skip connection.

        :param x: input real data. Because the output is ``h3 + x`` and h3
            has 3 channels, x presumably has 3 channels as well — TODO
            confirm against the caller.
        :param reuse: if True, reuse the variables previously created under
            the 'lens' variable scope instead of creating new ones.
        :return: lens x (``lx``), i.e. ``h3 + x``. Presumably the same shape
            as ``x``; the sum only requires broadcast-compatible shapes.
        """
        with tf.variable_scope('lens', reuse=reuse):
            # reuse is passed explicitly to every op. This matches the
            # residual blocks below and the conv layers in discriminator_net.
            h0 = ops.conv2d(x,
                            output_num=self.size,
                            stride=1,
                            filter_size=3,
                            name='l_h0',
                            reuse=reuse)
            h0 = ops.lrelu(h0, name='l_l0')

            h1 = ops.res_block3_3(h0, name='l_res_1', reuse=reuse)
            h2 = ops.res_block3_3(h1, name='l_res_2', reuse=reuse)

            # The 'l_h4' / 'l_l4' names do not follow the h3 numbering. They
            # are kept unchanged so existing variable names (and any saved
            # checkpoints) still match.
            h3 = ops.conv2d(h2,
                            output_num=3,
                            stride=1,
                            filter_size=3,
                            name='l_h4',
                            reuse=reuse)
            h3 = ops.lrelu(h3, leak=0.4, name='l_l4')
            # Global skip connection: the output is x plus the learned term h3.
            h3 = h3 + x
            return h3