def generator_single_chart(self, noise):
    """Build the single-chart generator graph and return its output tensor.

    ``noise`` passes through these stages in order:

    1. A fully connected layer.
    2. A reshape into a 4-D feature map.
    3. An input convolution, then one ``GenConvBlock`` per consecutive pair
       in ``self.config.depths``.
    4. A 1x1 output convolution to ``self.config.depth_dim`` channels.

    Toric symmetry is then enforced on the result. If
    ``self.config.normalize_charts`` is set, the mean over axes 2 and 3 is
    subtracted.

    All variables are created under the ``"generator"`` variable scope with
    ``tf.AUTO_REUSE``, so repeated calls share weights.

    Args:
        noise: latent input tensor, presumably shaped
            (batch, self.config.latent_vec_dim) -- TODO confirm with caller.

    Returns:
        The generator output tensor.
    """
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        depths = self.config.depths
        spatial = self.config.spatial

        # Project the latent vector, then reshape it into a
        # (batch, depths[0], spatial[0], spatial[0]) feature map.
        output = layers.linear.Linear('input_FC', self.config.latent_vec_dim,
                                      (spatial[0] ** 2) * depths[0], noise)
        output = tf.nn.relu(output)
        output = tf.reshape(output, [-1, depths[0], spatial[0], spatial[0]])

        # convolution layers
        output = Conv2D('conv_in_{0}-{1}_{2}'.format(depths[0], depths[0], spatial[0]),
                        depths[0], depths[0], 3, output, spatial[0])
        output = tf.nn.relu(output)
        for ind, depth in enumerate(depths[:-1]):
            # NOTE(review): 'conv__' (double underscore) differs from the
            # 'conv_' used by generator(). It is left unchanged because the
            # name presumably determines variable names stored in
            # checkpoints.
            output, _, _ = layers.conv_block.GenConvBlock(
                'conv__{3}_{0}-{1}_{2}'.format(depth, depths[ind + 1], spatial[ind], ind),
                depth, depths[ind + 1], 3, output, spatial[ind])

        # Use depths[-1] rather than the loop index. If depths has a single
        # entry the loop above never runs and the index would be unbound.
        output = Conv2D('conv_out', depths[-1], self.config.depth_dim, 1,
                        output, self.config.spatial_dim)

        # enforce toric symmetry
        output = toric_symmetry(output, self.config.spatial_dim)

        # reduce mean: subtract the mean taken over axes 2 and 3
        if self.config.normalize_charts:
            data_mean, _ = tf.nn.moments(output, axes=[2, 3], keep_dims=True)
            output = tf.subtract(output, data_mean)
        return output
def generator(self, noise):
    """Build the generator graph selected by ``self.config.model``.

    If ``self.config.model == 'single_chart'``, this delegates to
    ``self.generator_single_chart``. Otherwise it builds the same
    convolutional stack and enforces toric symmetry. It then applies the
    landmark consistency layer (``project_ST``) once ``self.cur_epoch_tensor``
    exceeds ``self.config.kick_projection``. Finally, if
    ``self.config.normalize_charts`` is set, it subtracts the mean over
    axes 2 and 3.

    All variables are created under the ``"generator"`` variable scope with
    ``tf.AUTO_REUSE``, so repeated calls share weights.

    Args:
        noise: latent input tensor, presumably shaped
            (batch, self.config.latent_vec_dim) -- TODO confirm with caller.

    Returns:
        The generator output tensor.
    """
    if self.config.model == 'single_chart':
        return self.generator_single_chart(noise)

    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        depths = self.config.depths
        spatial = self.config.spatial

        # Project the latent vector, then reshape it into a
        # (batch, depths[0], spatial[0], spatial[0]) feature map.
        output = layers.linear.Linear('input_FC', self.config.latent_vec_dim,
                                      (spatial[0] ** 2) * depths[0], noise)
        output = tf.nn.relu(output)
        output = tf.reshape(output, [-1, depths[0], spatial[0], spatial[0]])

        # convolution layers
        output = Conv2D('conv_in_{0}-{1}_{2}'.format(depths[0], depths[0], spatial[0]),
                        depths[0], depths[0], 3, output, spatial[0])
        output = tf.nn.relu(output)
        for ind, depth in enumerate(depths[:-1]):
            output, _, _ = layers.conv_block.GenConvBlock(
                'conv_{3}_{0}-{1}_{2}'.format(depth, depths[ind + 1], spatial[ind], ind),
                depth, depths[ind + 1], 3, output, spatial[ind])

        # Use depths[-1] rather than the loop index. If depths has a single
        # entry the loop above never runs and the index would be unbound.
        output = Conv2D('conv_out', depths[-1], self.config.depth_dim, 1,
                        output, self.config.spatial_dim)

        # enforce toric symmetry
        output = toric_symmetry(output, self.config.spatial_dim)

        # landmark consistency layer
        def project_ST_op():
            """True branch of the tf.cond below: apply project_ST."""
            if self.config.gamma_decay is not None:
                # gamma * gamma_decay ** (cur_epoch - kick_projection)
                gamma = tf.multiply(
                    tf.pow(self.config.gamma_decay,
                           tf.cast(tf.subtract(self.cur_epoch_tensor,
                                               self.config.kick_projection),
                                   dtype='float32')),
                    self.config.gamma)
            else:
                gamma = self.config.gamma
            # tf.identity converts the result to a tensor, which tf.cond
            # requires as a branch output. It already depends on the
            # projection through its input, so no extra control dependency
            # is needed.
            return tf.identity(project_ST(self, output, gamma))

        # The landmark consistency layer kicks in once the current epoch is
        # strictly greater than kick_projection. tf.cond calls both branch
        # functions during this statement, so the false branch captures
        # `output` before it is rebound.
        output = tf.cond(tf.greater(self.cur_epoch_tensor, self.config.kick_projection),
                         project_ST_op, lambda: output)

        # reduce mean: subtract the mean taken over axes 2 and 3
        if self.config.normalize_charts:
            data_mean, _ = tf.nn.moments(output, axes=[2, 3], keep_dims=True)
            output = tf.subtract(output, data_mean)
        return output