コード例 #1
0
    def generator_single_chart(self, noise):
        """Build the generator graph for the single-chart model.

        Maps ``noise`` through a fully connected layer (``input_FC``), an input
        3x3 convolution, one ``GenConvBlock`` per consecutive pair of entries in
        ``config.depths``, and a 1x1 output convolution (``conv_out``). It then
        applies ``toric_symmetry``. If ``config.normalize_charts`` is set, it
        subtracts the mean taken over axes 2 and 3 (kept dims).

        All variables are created under the ``"generator"`` variable scope with
        ``tf.AUTO_REUSE``, so repeated calls share weights.

        Args:
            noise: latent input tensor fed to ``input_FC``; presumably shaped
                (N, config.latent_vec_dim) -- TODO confirm against callers.

        Returns:
            The generated tensor. Its channel count is ``config.depth_dim`` as
            passed to ``conv_out``. The spatial size is presumably
            ``config.spatial_dim`` (confirm against ``Conv2D``).
        """
        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
            depths = self.config.depths
            spatial = self.config.spatial

            # Project the latent vector and reshape it to
            # (N, depths[0], spatial[0], spatial[0]), presumably channels-first.
            output = layers.linear.Linear('input_FC', self.config.latent_vec_dim,
                                          (spatial[0] ** 2) * depths[0], noise)
            output = tf.nn.relu(output)
            output = tf.reshape(output, [-1, depths[0], spatial[0], spatial[0]])

            # convolution layers
            # presumably N x depths[0] x spatial[0] x spatial[0] (assumes Conv2D
            # preserves the spatial size)
            output = Conv2D('conv_in_{0}-{1}_{2}'.format(depths[0], depths[0], spatial[0]),
                            depths[0], depths[0], 3, output, spatial[0])
            output = tf.nn.relu(output)

            # One block per consecutive (input depth, output depth) pair.
            # NOTE(review): this name uses the 'conv__' prefix (double underscore)
            # while generator() uses 'conv_'. It is kept unchanged because renaming
            # would change variable names and break existing checkpoints.
            for ind, (depth_in, depth_out) in enumerate(zip(depths[:-1], depths[1:])):
                output, _, _ = layers.conv_block.GenConvBlock(
                    'conv__{3}_{0}-{1}_{2}'.format(depth_in, depth_out, spatial[ind], ind),
                    depth_in, depth_out, 3, output, spatial[ind])

            # The channel count entering conv_out is always the last depth. Index
            # depths[-1] directly instead of reusing the loop index: with a single
            # entry in `depths` the loop never runs and the index would be unbound.
            output = Conv2D('conv_out', depths[-1], self.config.depth_dim, 1, output,
                            self.config.spatial_dim)

            # enforce toric symmetry
            output = toric_symmetry(output, self.config.spatial_dim)

            # reduce mean: subtract the mean over axes 2 and 3, computed per
            # index along axes 0 and 1.
            if self.config.normalize_charts:
                data_mean, _ = tf.nn.moments(output, axes=[2, 3], keep_dims=True)
                output = tf.subtract(output, data_mean)

        return output
コード例 #2
0
    def generator(self, noise):
        """Build the generator graph, dispatching on ``config.model``.

        When ``config.model == 'single_chart'`` this delegates to
        ``generator_single_chart``. Otherwise it builds:
        FC -> input 3x3 conv -> ``GenConvBlock`` stack -> 1x1 ``conv_out`` ->
        ``toric_symmetry``. It then adds a landmark-consistency projection
        (``project_ST``) that is applied only while ``cur_epoch_tensor`` is
        greater than ``config.kick_projection``. If ``config.normalize_charts``
        is set, it finally subtracts the mean over axes 2 and 3 (kept dims).

        All variables are created under the ``"generator"`` variable scope with
        ``tf.AUTO_REUSE``.

        Args:
            noise: latent input tensor fed to ``input_FC``; presumably shaped
                (N, config.latent_vec_dim) -- TODO confirm against callers.

        Returns:
            The generated tensor. Its channel count is ``config.depth_dim`` as
            passed to ``conv_out``. The spatial size is presumably
            ``config.spatial_dim`` (confirm against ``Conv2D``).
        """
        if self.config.model == 'single_chart':
            return self.generator_single_chart(noise)

        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
            depths = self.config.depths
            spatial = self.config.spatial

            # Project the latent vector and reshape it to
            # (N, depths[0], spatial[0], spatial[0]), presumably channels-first.
            output = layers.linear.Linear('input_FC', self.config.latent_vec_dim,
                                          (spatial[0] ** 2) * depths[0], noise)
            output = tf.nn.relu(output)
            output = tf.reshape(output, [-1, depths[0], spatial[0], spatial[0]])

            # convolution layers
            # presumably N x depths[0] x spatial[0] x spatial[0] (assumes Conv2D
            # preserves the spatial size)
            output = Conv2D('conv_in_{0}-{1}_{2}'.format(depths[0], depths[0], spatial[0]),
                            depths[0], depths[0], 3, output, spatial[0])
            output = tf.nn.relu(output)

            # One block per consecutive (input depth, output depth) pair.
            for ind, (depth_in, depth_out) in enumerate(zip(depths[:-1], depths[1:])):
                output, _, _ = layers.conv_block.GenConvBlock(
                    'conv_{3}_{0}-{1}_{2}'.format(depth_in, depth_out, spatial[ind], ind),
                    depth_in, depth_out, 3, output, spatial[ind])

            # The channel count entering conv_out is always the last depth. Index
            # depths[-1] directly instead of reusing the loop index: with a single
            # entry in `depths` the loop never runs and the index would be unbound.
            output = Conv2D('conv_out', depths[-1], self.config.depth_dim, 1, output,
                            self.config.spatial_dim)

            # enforce toric symmetry
            output = toric_symmetry(output, self.config.spatial_dim)

            # landmark consistency layer
            # Bind the pre-projection tensor explicitly so both cond branches
            # refer to it, rather than to the name `output`, which is reassigned below.
            unprojected = output

            def project_ST_op():
                if self.config.gamma_decay is not None:
                    # gamma * gamma_decay ** (cur_epoch - kick_projection)
                    epochs_since_kick = tf.cast(
                        tf.subtract(self.cur_epoch_tensor, self.config.kick_projection),
                        dtype='float32')
                    gamma = tf.multiply(tf.pow(self.config.gamma_decay, epochs_since_kick),
                                        self.config.gamma)
                else:
                    gamma = self.config.gamma
                projected = project_ST(self, unprojected, gamma)
                # The data dependency on `projected` already orders this after
                # project_ST, so no explicit control dependency is needed.
                return tf.identity(projected)

            # The landmark consistency layer kicks in only once the current epoch
            # is strictly greater than config.kick_projection.
            output = tf.cond(tf.greater(self.cur_epoch_tensor, self.config.kick_projection),
                             project_ST_op,
                             lambda: unprojected)

            # reduce mean: subtract the mean over axes 2 and 3, computed per
            # index along axes 0 and 1.
            if self.config.normalize_charts:
                data_mean, _ = tf.nn.moments(output, axes=[2, 3], keep_dims=True)
                output = tf.subtract(output, data_mean)

        return output